├── .gitignore ├── README.md ├── annotation ├── c3-d-dev.txt ├── c3-d-test.txt ├── c3-m-dev.txt └── c3-m-test.txt ├── bert ├── LICENSE ├── __init__.py ├── convert_tf_checkpoint_to_pytorch.py ├── extract_features.py ├── modeling.py ├── optimization.py ├── run_classifier.py └── tokenization.py ├── data ├── c3-d-dev.json ├── c3-d-test.json ├── c3-d-train.json ├── c3-m-dev.json ├── c3-m-test.json └── c3-m-train.json └── license.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | C3 2 | ===== 3 | Overview 4 | -------- 5 | This repository maintains **C3**, the first free-form multiple-**C**hoice **C**hinese machine reading **C**omprehension dataset. 
6 | 7 | * Paper: https://arxiv.org/abs/1904.09679 8 | ``` 9 | @article{sun2019investigating, 10 | title={Investigating Prior Knowledge for Challenging Chinese Machine Reading Comprehension}, 11 | author={Sun, Kai and Yu, Dian and Yu, Dong and Cardie, Claire}, 12 | journal={Transactions of the Association for Computational Linguistics}, 13 | year={2020}, 14 | url={https://arxiv.org/abs/1904.09679v3} 15 | } 16 | ``` 17 | 18 | Files in this repository: 19 | 20 | * ```license.txt```: the license of C3. 21 | * ```data/c3-{m,d}-{train,dev,test}.json```: the dataset files, where m and d represent "**m**ixed-genre" and "**d**ialogue", respectively. The data format is as follows. 22 | ``` 23 | [ 24 | [ 25 | [ 26 | document 1 27 | ], 28 | [ 29 | { 30 | "question": document 1 / question 1, 31 | "choice": [ 32 | document 1 / question 1 / answer option 1, 33 | document 1 / question 1 / answer option 2, 34 | ... 35 | ], 36 | "answer": document 1 / question 1 / correct answer option 37 | }, 38 | { 39 | "question": document 1 / question 2, 40 | "choice": [ 41 | document 1 / question 2 / answer option 1, 42 | document 1 / question 2 / answer option 2, 43 | ... 44 | ], 45 | "answer": document 1 / question 2 / correct answer option 46 | }, 47 | ... 48 | ], 49 | document 1 / id 50 | ], 51 | [ 52 | [ 53 | document 2 54 | ], 55 | [ 56 | { 57 | "question": document 2 / question 1, 58 | "choice": [ 59 | document 2 / question 1 / answer option 1, 60 | document 2 / question 1 / answer option 2, 61 | ... 62 | ], 63 | "answer": document 2 / question 1 / correct answer option 64 | }, 65 | { 66 | "question": document 2 / question 2, 67 | "choice": [ 68 | document 2 / question 2 / answer option 1, 69 | document 2 / question 2 / answer option 2, 70 | ... 71 | ], 72 | "answer": document 2 / question 2 / correct answer option 73 | }, 74 | ... 75 | ], 76 | document 2 / id 77 | ], 78 | ... 79 | ] 80 | ``` 81 | * ```annotation/c3-{m,d}-{dev,test}.txt```: question type annotations. Each file contains 150 annotated instances. We adopt the following abbreviations: 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 |
| | Abbreviation | Question Type |
| --- | --- | --- |
| Matching | m | Matching |
| Prior knowledge | l | Linguistic |
| | s | Domain-specific |
| | c-a | Arithmetic |
| | c-o | Connotation |
| | c-e | Cause-effect |
| | c-i | Implication |
| | c-p | Part-whole |
| | c-d | Precondition |
| | c-h | Scenario |
| | c-n | Other |
| Supporting Sentences | 0 | Single Sentence |
| | 1 | Multiple sentences |
| | 2 | Independent |
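  Each line in these files gives a document ID, a question index, and a type field such as ```l|c-e||1```: the abbreviations before ```||``` are question types from the table above, and the final digit appears to be the supporting-sentence label (0/1/2). A minimal parsing sketch under that reading (the function name is ours, not part of the repository):
```
# Sketch only: parse one line of annotation/c3-{m,d}-{dev,test}.txt.
# Assumed layout: documentID, questionIndex, then "type1|type2||s",
# where s is the supporting-sentence label (0/1/2) from the table above.
def parse_annotation_line(line):
    doc_id, q_index, type_field = line.split()
    types, _, support = type_field.rpartition("||")
    return {
        "documentID": doc_id,                 # e.g. "47-275" or "m27-11"
        "questionIndex": int(q_index),
        "types": types.split("|"),            # e.g. ["l", "c-e"]
        "supportingSentences": int(support),  # 0 single, 1 multiple, 2 independent
    }

# parse_annotation_line("47-275 1 c-i||1")
# -> {'documentID': '47-275', 'questionIndex': 1, 'types': ['c-i'], 'supportingSentences': 1}
```
  The header line (```documentID questionIndex type```) should be skipped before applying it.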
150 | 151 | 152 | * ```bert``` folder: code of Chinese BERT, BERT-wwm, and BERT-wwm-ext baselines. The code is derived from [this repository](https://github.com/nlpdata/mrc_bert_baseline). Below are detailed instructions on fine-tuning Chinese BERT on C3. 153 | 1. Download and unzip the pre-trained Chinese BERT from [here](https://github.com/google-research/bert), and set up the environment variable for BERT by ```export BERT_BASE_DIR=/PATH/TO/BERT/DIR```. 154 | 2. Copy the dataset folder ```data``` to ```bert/```. 155 | 3. In ```bert```, execute ```python convert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path=$BERT_BASE_DIR/bert_model.ckpt --bert_config_file=$BERT_BASE_DIR/bert_config.json --pytorch_dump_path=$BERT_BASE_DIR/pytorch_model.bin```. 156 | 4. Execute ```python run_classifier.py --task_name c3 --do_train --do_eval --data_dir . --vocab_file $BERT_BASE_DIR/vocab.txt --bert_config_file $BERT_BASE_DIR/bert_config.json --init_checkpoint $BERT_BASE_DIR/pytorch_model.bin --max_seq_length 512 --train_batch_size 24 --learning_rate 2e-5 --num_train_epochs 8.0 --output_dir c3_finetuned --gradient_accumulation_steps 3```. 157 | 5. The resulting fine-tuned model, predictions, and evaluation results are stored in ```bert/c3_finetuned```. 158 | 159 | **Note**: 160 | 1. Fine-tuning Chinese BERT-wwm or BERT-wwm-ext follows the same steps except for downloading their pre-trained language models. 161 | 2. There is randomness in model training, so you may want to run multiple times to choose the best model based on development set performance. You may also want to set different seeds (specify ```--seed``` when executing ```run_classifier.py```). 162 | 3. Depending on your hardware, you may need to change ```gradient_accumulation_steps```. 163 | 4. The code has been tested with Python 3.6 and PyTorch 1.0. 
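For reference, a minimal sketch of walking one of the ```data``` files using the nested structure described above (illustrative only, not part of the repository; it assumes, per the format description, that ```answer``` is the text of the correct option and therefore appears in ```choice```):
```
# Sketch only: iterate over a C3 data file using the nested structure documented above.
import json

with open("data/c3-d-dev.json", encoding="utf-8") as f:
    dataset = json.load(f)

for document, questions, doc_id in dataset:
    passage = "\n".join(document)  # each document is stored as a list of strings
    for q in questions:
        # Assuming the correct option appears verbatim in "choice",
        # its index can serve as a classification label.
        label = q["choice"].index(q["answer"])
        print(doc_id, len(passage), q["question"], label)
```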
164 | -------------------------------------------------------------------------------- /annotation/c3-d-dev.txt: -------------------------------------------------------------------------------- 1 | documentID questionIndex type 2 | 47-275 1 c-i||1 3 | 11-57 1 c-d||1 4 | 11-30 1 m||0 5 | 17-3 1 l||1 6 | m27-11 1 m||0 7 | m27-11 2 m||0 8 | m27-11 3 l||1 9 | m27-11 4 l|c-e||1 10 | m27-11 5 l||1 11 | 46-196 1 c-i||1 12 | 28-8 1 c-o||1 13 | 46-151 1 c-o||1 14 | 2-29 1 m||1 15 | 37-330 1 c-o||1 16 | 7-10 1 l||1 17 | 37-120 1 c-h||1 18 | m12-8 1 c-h||1 19 | m12-8 2 m||1 20 | 32-96 1 c-h||1 21 | 10-81 1 l||1 22 | 57-158 1 c-h||1 23 | 56-261 1 l||0 24 | 34-65 1 c-h||1 25 | 46-40 1 s||1 26 | m19-1 1 l||1 27 | m19-1 2 m||0 28 | m19-1 3 m||0 29 | m19-1 4 l|c-e||1 30 | m19-1 5 l||0 31 | 41-38 1 c-i||1 32 | 46-91 1 c-i||1 33 | 54-217 1 m||0 34 | 51-234 1 c-i||1 35 | 33-72 1 l||1 36 | 34-117 1 l|c-i||1 37 | 53-121 1 c-h||1 38 | 54-206 1 c-d||1 39 | 54-358 1 c-h||1 40 | 54-103 1 m||0 41 | 35-63 1 c-i||1 42 | 34-63 1 c-i||1 43 | 52-141 1 l||1 44 | 42-37 1 m||0 45 | 34-4 1 c-p||0 46 | 56-29 1 c-h||1 47 | 57-214 1 c-h||1 48 | 57-203 1 l|c-e||1 49 | 49-85 1 c-o||1 50 | 37-240 1 c-d|c-h||1 51 | 42-103 1 c-h||1 52 | 38-54 1 l||1 53 | 56-234 1 l||1 54 | 12-62 1 l||1 55 | 41-134 1 s|c-a|c-p||1 56 | 32-200 1 c-i||1 57 | 35-9 1 c-i||1 58 | 5-59 1 l||1 59 | 28-5 1 c-a||1 60 | 17-40 1 c-o||1 61 | 4-14 1 c-o||1 62 | 37-40 1 m||0 63 | 52-112 1 l||0 64 | 56-283 1 l||0 65 | 36-38 1 m||0 66 | m15-4 1 l||1 67 | m15-4 2 c-h||1 68 | 56-19 1 m||0 69 | 52-176 1 m||0 70 | 37-290 1 m||1 71 | 47-40 1 c-i||1 72 | 19-26 1 c-h||1 73 | 6-13 1 c-h||1 74 | 37-341 1 c-e|c-i|l||1 75 | 39-116 1 c-d||1 76 | 29-128 1 c-h||1 77 | 46-178 1 c-i||1 78 | 49-81 1 c-p||0 79 | 3-37 1 c-a||1 80 | 51-14 1 l||1 81 | 56-89 1 m||0 82 | 35-98 1 m||1 83 | 42-63 1 s||1 84 | m29-7 1 l|c-e||0 85 | m29-7 2 l||1 86 | m29-7 3 l||0 87 | m29-7 4 c-h||1 88 | m29-7 5 l||1 89 | 46-85 1 l||1 90 | 44-74 1 c-h||1 91 | 48-117 1 l||1 92 | 57-45 1 c-i||1 93 | 40-192 1 l||1 94 | 40-124 1 c-h||1 95 | 55-63 1 c-h||1 96 | 38-103 1 m||0 97 | 43-43 1 m||1 98 | 37-45 1 c-h||1 99 | 34-182 1 c-h||1 100 | 54-388 1 c-i|c-p||1 101 | 53-106 1 c-o||1 102 | 20-3 1 c-i||0 103 | 45-8 1 m||0 104 | 39-196 1 c-h||1 105 | 50-84 1 c-i|c-p||1 106 | 37-277 1 c-i||1 107 | 47-95 1 c-i||1 108 | 35-60 1 c-h||1 109 | 34-192 1 c-h||1 110 | 48-245 1 c-h|l||1 111 | 54-9 1 l||0 112 | 27-35 1 c-a||1 113 | 37-377 1 c-o||1 114 | 48-37 1 l||1 115 | 20-36 1 l||1 116 | 38-92 1 l||0 117 | 39-244 1 l||1 118 | 29-68 1 m||0 119 | 48-240 1 l||1 120 | 57-134 1 c-i|c-a||1 121 | 56-291 1 m||1 122 | 39-121 1 m||1 123 | 48-193 1 c-h||1 124 | 52-168 1 c-h||1 125 | 11-55 1 l||0 126 | 21-6 1 c-h||1 127 | 56-5 1 c-h||1 128 | 45-33 1 l||1 129 | 37-4 1 l||0 130 | 28-50 1 c-h||1 131 | 5-21 1 c-e||1 132 | 37-231 1 m||1 133 | 54-315 1 m||0 134 | 31-135 1 l|c-e||1 135 | 56-195 1 c-n||1 136 | 29-102 1 l||0 137 | 56-177 1 l||0 138 | 41-132 1 l||1 139 | 54-44 1 l|c-e||1 140 | 20-30 1 l||1 141 | 51-227 1 c-i||1 142 | 9-2 1 m||1 143 | 41-83 1 l|c-e||0 144 | 47-118 1 l||0 145 | 35-71 1 m||1 146 | 55-66 1 c-h||1 147 | 32-175 1 c-i||1 148 | m33-21 1 l||1 149 | m33-21 2 l||1 150 | m33-21 3 c-p||1 151 | m33-21 4 l||1 152 | -------------------------------------------------------------------------------- /annotation/c3-d-test.txt: -------------------------------------------------------------------------------- 1 | documentID questionIndex type 2 | 47-37 1 l||1 3 | 50-94 1 c-i||0 4 | 48-292 1 c-o||0 5 | 47-248 1 c-i||1 6 | 26-30 1 c-e||1 7 | 
10-114 1 c-e|c-i||1 8 | 38-119 1 c-d||1 9 | 34-203 1 c-i|c-p||1 10 | m23-4 1 m||0 11 | m23-4 2 l||1 12 | m23-4 3 l||1 13 | m23-4 4 m||1 14 | m23-4 5 m||0 15 | 32-6 1 c-i||0 16 | 37-84 1 l|c-i||1 17 | 47-104 1 c-h|c-i||1 18 | 53-57 1 l|c-h|c-i||0 19 | 49-148 1 c-h||1 20 | 56-212 1 c-h||1 21 | 40-255 1 c-h|l||1 22 | 29-88 1 m||1 23 | m22-24 1 l||0 24 | m22-24 2 c-h||1 25 | m22-24 3 l||1 26 | m22-24 4 l||0 27 | m22-24 5 l||1 28 | 30-174 1 c-i||1 29 | 49-24 1 c-a||1 30 | 31-196 1 c-d||0 31 | 57-58 1 c-h||1 32 | 29-204 1 m||1 33 | 52-9 1 c-h||1 34 | 33-85 1 c-a||1 35 | 37-352 1 c-o||1 36 | 46-271 1 c-h|l||1 37 | 56-49 1 m||0 38 | 10-24 1 c-h||1 39 | 40-73 1 c-i||1 40 | 29-109 1 m||0 41 | 37-363 1 c-i|c-p||1 42 | 34-66 1 c-h||1 43 | 40-32 1 c-p||1 44 | 3-50 1 l||0 45 | 44-116 1 l||1 46 | 32-116 1 l||1 47 | 46-37 1 l||0 48 | 52-79 1 c-h|c-e||1 49 | 56-90 1 c-e|c-i||1 50 | 5-22 1 c-h||1 51 | 54-139 1 c-i||1 52 | 56-165 1 c-i||1 53 | 50-132 1 c-h||1 54 | 49-32 1 c-o||0 55 | 41-216 1 c-h||1 56 | 32-174 1 c-h||1 57 | 29-229 1 c-h||1 58 | 37-379 1 c-i||0 59 | 9-4 1 l||1 60 | 51-47 1 c-h||1 61 | 51-122 1 c-a|l||1 62 | 26-21 1 l||1 63 | 46-150 1 c-i|l||1 64 | 43-8 1 c-i||1 65 | 17-38 1 c-h|c-d||1 66 | 56-102 1 c-h||0 67 | 39-102 1 c-i||0 68 | 24-11 1 c-h||1 69 | 51-149 1 l|c-a||1 70 | 48-202 1 c-e|c-i||1 71 | 13-65 1 c-h||1 72 | 47-231 1 c-i||1 73 | 37-210 1 c-i||0 74 | 15-10 1 c-o||1 75 | 33-100 1 c-a||1 76 | 52-53 1 c-h||0 77 | 33-94 1 c-i||1 78 | 35-99 1 c-e|c-p||1 79 | 19-41 1 c-i||1 80 | 46-243 1 m||0 81 | 12-32 1 c-h||1 82 | 32-77 1 c-h||1 83 | 37-157 1 c-i||1 84 | 57-40 1 m||0 85 | 40-299 1 c-i||1 86 | 41-121 1 c-p|l||1 87 | 10-39 1 c-i||0 88 | 54-223 1 c-o||1 89 | 6-8 1 c-e||0 90 | 46-272 1 c-p|c-a||0 91 | 10-86 1 c-h|l||1 92 | 48-95 1 c-i||0 93 | 10-21 1 c-i||1 94 | 41-211 1 c-i||1 95 | 51-204 1 c-d||0 96 | 39-117 1 c-h||1 97 | 39-132 1 c-e|l||1 98 | 48-214 1 c-h|l||1 99 | m2-3 1 c-d||1 100 | m2-3 2 m||0 101 | m2-3 3 c-i|l||1 102 | m2-3 4 c-p|l||1 103 | m2-3 5 c-i|l||1 104 | 45-130 1 c-o||1 105 | 35-34 1 c-h||1 106 | 54-95 1 c-h||1 107 | 20-21 1 c-i||1 108 | 53-2 1 c-i||1 109 | 56-230 1 c-h|c-d||1 110 | 7-3 1 c-d||1 111 | m23-23 1 m||1 112 | m23-23 2 c-o|l||1 113 | m23-23 3 c-h||0 114 | m23-23 4 c-i||1 115 | m23-23 5 m||1 116 | 57-212 1 c-h||1 117 | 29-43 1 c-i|c-h||1 118 | 32-104 1 m||0 119 | 46-70 1 l||1 120 | 45-52 1 m||0 121 | 5-35 1 l||1 122 | 56-215 1 c-h||1 123 | 32-27 1 c-h||1 124 | 32-9 1 l||1 125 | 37-114 1 c-i||1 126 | m33-12 1 l||1 127 | m33-12 2 l||1 128 | m33-12 3 l||1 129 | m33-12 4 l||2 130 | 5-60 1 c-d||0 131 | 36-16 1 c-e||1 132 | 45-67 1 c-h||1 133 | 40-187 1 l||0 134 | 46-161 1 c-e|l||1 135 | 56-167 1 c-i|c-p||1 136 | 52-103 1 l||1 137 | 10-28 1 m||1 138 | 54-96 1 c-a|l||0 139 | 40-242 1 c-h||1 140 | 23-1 1 c-i||1 141 | 32-90 1 c-e|l||1 142 | 39-236 1 c-p|c-a||1 143 | 20-69 1 l|c-a||1 144 | 49-136 1 c-d||1 145 | 50-19 1 c-h||1 146 | 50-81 1 c-i||1 147 | 12-22 1 c-o||1 148 | 41-102 1 c-h||1 149 | 32-119 1 c-h||1 150 | 56-282 1 c-h||1 151 | m28-9 1 m||0 152 | -------------------------------------------------------------------------------- /annotation/c3-m-dev.txt: -------------------------------------------------------------------------------- 1 | documentID questionIndex type 2 | 11-67 1 c-h||1 3 | 8-707 1 l||0 4 | m1-36 1 l||0 5 | m1-36 2 l||1 6 | m1-36 3 c-i||1 7 | m6-130 1 l||1 8 | m6-130 2 l||1 9 | m6-130 3 l|c-e||1 10 | m6-130 4 l|c-e||1 11 | 8-160 1 c-a|l||0 12 | 11-158 1 c-i||1 13 | 10-74 1 l||0 14 | 9-520 1 c-d||0 15 | 9-269 1 l|c-e||1 16 | 12-165 1 m||0 17 | 9-64 1 
c-o||0 18 | 1-109 1 c-i||1 19 | 8-632 1 l||0 20 | 8-333 1 l||0 21 | 12-200 1 c-h||0 22 | 12-47 1 l||0 23 | m1-43 1 l||1 24 | m1-43 2 l|c-e||1 25 | m1-43 3 c-h||1 26 | m1-43 4 c-i|c-p||1 27 | 9-12 1 l||0 28 | m7-55 1 c-e|l||0 29 | m7-55 2 c-e|c-h||1 30 | m7-55 3 l||1 31 | m7-55 4 c-i||1 32 | 10-38 1 l||0 33 | 2-227 1 l||0 34 | m6-132 1 m||0 35 | m6-132 2 c-i||1 36 | 8-381 1 c-h||0 37 | 9-392 1 l||0 38 | 8-684 1 c-d||0 39 | 9-190 1 c-h||0 40 | m12-235 1 c-i||1 41 | m12-235 2 c-a||1 42 | m12-235 3 l||1 43 | m12-235 4 l||1 44 | 9-275 1 c-i||1 45 | m13-358 1 l||0 46 | m13-358 2 l||0 47 | m13-358 3 l||0 48 | m13-358 4 l||0 49 | m13-234 1 l||0 50 | m13-234 2 c-e|l||1 51 | m13-234 3 c-h||0 52 | m13-234 4 c-i||1 53 | m8-101 1 c-i||1 54 | m8-101 2 m||0 55 | m8-101 3 l||0 56 | m8-101 4 c-i||1 57 | 12-258 1 l||0 58 | 11-272 1 c-p|l||0 59 | 2-15 1 m||0 60 | m10-81 1 l||1 61 | m10-81 2 c-e|l||1 62 | m10-81 3 c-e|c-h||1 63 | m10-81 4 c-i||1 64 | 8-436 1 m||0 65 | m13-144 1 c-e|c-h||1 66 | m13-144 2 l||0 67 | m13-144 3 l||0 68 | m13-144 4 l||2 69 | 10-228 1 l||0 70 | m11-217 1 c-i||1 71 | m11-217 2 c-e|c-h||1 72 | m11-217 3 l||0 73 | m11-217 4 c-i||1 74 | 11-14 1 c-e||1 75 | 11-195 1 m||0 76 | m5-49 1 c-e|l||1 77 | m5-49 2 l||1 78 | 9-650 1 c-h||1 79 | m12-111 1 c-e|l||0 80 | m12-111 2 c-i||1 81 | m6-88 1 c-p|l||1 82 | m6-88 2 c-e|l||0 83 | m6-88 3 c-i||0 84 | 11-355 1 c-i||1 85 | 8-84 1 l||0 86 | 6-14 1 c-i||0 87 | 10-165 1 l||0 88 | 9-361 1 l||0 89 | m8-41 1 l||1 90 | m8-41 2 l||0 91 | m13-145 1 l||0 92 | m13-145 2 c-i||1 93 | m13-145 3 c-p|c-h||1 94 | m13-145 4 l||1 95 | 9-497 1 c-i||0 96 | m3-16 1 l||1 97 | m3-16 2 l||1 98 | 9-71 1 l||0 99 | 11-182 1 l||0 100 | m7-32 1 m||0 101 | m7-32 2 m||0 102 | m7-32 3 l||1 103 | 12-218 1 l||0 104 | 2-131 1 c-h||0 105 | 9-350 1 l||0 106 | 9-387 1 l||0 107 | 9-556 1 c-d||0 108 | m11-242 1 l||1 109 | m11-242 2 l||0 110 | m11-242 3 c-e|l||0 111 | m11-242 4 c-i||1 112 | m2-23 1 c-e|l||0 113 | m2-23 2 c-i||1 114 | m12-174 1 l||1 115 | m12-174 2 l||1 116 | m12-174 3 c-i||1 117 | m12-174 4 c-p||c-i||1 118 | m11-239 1 c-e|l||0 119 | m11-239 2 l||0 120 | m11-239 3 m||0 121 | m11-239 4 c-i||1 122 | 2-113 1 l||0 123 | 8-727 1 c-i||0 124 | m5-191 1 m||0 125 | m5-191 2 l||1 126 | m5-191 3 l||0 127 | m5-191 4 l||1 128 | m2-74 1 c-e||1 129 | m2-74 2 c-e|l||1 130 | m2-74 3 l||1 131 | m2-74 4 c-i||1 132 | 10-313 1 l||0 133 | 9-580 1 c-h||0 134 | 8-181 1 c-h||0 135 | m12-218 1 l||1 136 | m12-218 2 c-i||1 137 | m12-218 3 l||1 138 | m12-218 4 c-i||1 139 | m10-47 1 l||2 140 | m10-47 2 l||0 141 | m10-47 3 c-o||1 142 | m10-47 4 l||1 143 | m13-198 1 m||0 144 | m13-198 2 l||1 145 | m13-198 3 c-e||0 146 | 9-577 1 c-i||0 147 | 3-28 1 l||1 148 | 12-236 1 l||1 149 | m11-277 1 l||0 150 | m11-277 2 l||0 151 | m11-277 3 l||0 152 | -------------------------------------------------------------------------------- /annotation/c3-m-test.txt: -------------------------------------------------------------------------------- 1 | documentID questionIndex type 2 | 12-21 1 l||1 3 | 3-66 1 c-i||1 4 | m12-64 1 c-n||1 5 | m12-64 2 m||1 6 | m12-64 3 c-d||1 7 | m12-64 4 l||1 8 | m5-132 1 l||0 9 | m5-132 2 c-n||1 10 | 1-146 1 c-h|l||1 11 | 11-363 1 l||0 12 | 8-671 1 c-i||0 13 | m7-83 1 c-e|c-i||1 14 | m7-83 2 c-i||1 15 | m7-83 3 c-i||1 16 | m7-83 4 c-i||1 17 | 9-336 1 c-h|c-i||1 18 | 12-350 1 l||0 19 | 8-53 1 c-a||0 20 | m12-108 1 c-p|l||0 21 | m12-108 2 c-n||1 22 | m12-108 3 m||0 23 | 9-120 1 c-a|c-n||0 24 | m13-158 1 m||1 25 | m13-158 2 c-o||1 26 | m13-158 3 l||0 27 | 9-133 1 c-h||0 28 | m11-90 1 c-e|c-i||0 29 
| m11-90 2 l||0 30 | m11-90 3 l||1 31 | 10-55 1 m||1 32 | m5-112 1 m||0 33 | m5-112 2 m||0 34 | m5-112 3 m||0 35 | m5-112 4 l||1 36 | 12-397 1 l||1 37 | m12-76 1 c-e|l||1 38 | m12-76 2 c-e|l||1 39 | m12-76 3 c-e|c-i||1 40 | m12-76 4 c-i|l||1 41 | m7-129 1 c-e|c-p||1 42 | m7-129 2 m||1 43 | m7-129 3 c-e||0 44 | m7-129 4 c-a|c-p||1 45 | m6-2 1 m||0 46 | m6-2 2 c-a|c-h||1 47 | m6-2 3 c-e||0 48 | m6-2 4 l||0 49 | m12-129 1 m||0 50 | m12-129 2 m||1 51 | m12-129 3 m||0 52 | m8-81 1 c-d||1 53 | m8-81 2 s|c-a||1 54 | m8-81 3 s|c-n||2 55 | m8-81 4 c-n||1 56 | 9-5 1 c-i||0 57 | m11-26 1 m||0 58 | m11-26 2 c-e|l||1 59 | m11-26 3 c-e|c-h||1 60 | 3-3 1 l||0 61 | m12-4 1 c-e|l||1 62 | m12-4 2 c-e|l||1 63 | m12-4 3 c-i||1 64 | 12-406 1 m||1 65 | m12-207 1 l||1 66 | m12-207 2 c-e|l||1 67 | m12-207 3 c-n||1 68 | m12-207 4 l|c-d||1 69 | 9-97 1 c-h||1 70 | 8-572 1 c-h||0 71 | 1-21 1 c-h||0 72 | m11-34 1 c-e|l||0 73 | m11-34 2 m||1 74 | m11-34 3 c-i||1 75 | 9-65 1 c-i||0 76 | 8-306 1 c-h||0 77 | m6-30 1 c-e|l||0 78 | m6-30 2 l|c-p||0 79 | m6-30 3 l||0 80 | m12-172 1 m||1 81 | m12-172 2 c-e|l||1 82 | m12-172 3 m||1 83 | m12-172 4 l||1 84 | 8-255 1 c-n||0 85 | m11-243 1 c-e|l||0 86 | m11-243 2 l||1 87 | m11-243 3 l||0 88 | m11-243 4 l||1 89 | 2-161 1 l||0 90 | 11-36 1 m||0 91 | 12-26 1 l||0 92 | 1-5 1 c-i||1 93 | m10-7 1 c-n||0 94 | m10-7 2 c-p|c-i||1 95 | 8-96 1 c-e|l||1 96 | 10-4 1 c-i||1 97 | 11-81 1 l||0 98 | 8-652 1 c-a|c-p||0 99 | m5-199 1 m||0 100 | m5-199 2 m||0 101 | m5-199 3 m||0 102 | m5-95 1 c-i||1 103 | m5-95 2 l||1 104 | m5-95 3 l||2 105 | m5-95 4 l||2 106 | m5-95 5 c-o||1 107 | m5-129 1 c-h||0 108 | m5-129 2 c-p||0 109 | m5-129 3 l||1 110 | 8-658 1 c-e||0 111 | 8-251 1 l||0 112 | 2-153 1 c-h||0 113 | m5-119 1 c-a||1 114 | m5-119 2 l||0 115 | 8-112 1 l||0 116 | 8-711 1 c-i||0 117 | m2-20 1 m||0 118 | m2-20 2 c-p||0 119 | 12-127 1 l||0 120 | m11-111 1 c-h||1 121 | m11-111 2 l||2 122 | m11-111 3 m||0 123 | m11-111 4 l||1 124 | m11-111 5 c-h||1 125 | 6-15 1 l||0 126 | 8-122 1 c-i||0 127 | 9-438 1 c-d||0 128 | 8-466 1 l||0 129 | 10-316 1 c-h||1 130 | 5-24 1 l||2 131 | 12-60 1 l||0 132 | 1-181 1 l||0 133 | 6-7 1 l||0 134 | 9-44 1 c-i||0 135 | 8-451 1 m||0 136 | 9-245 1 c-d||0 137 | m5-4 1 c-i|c-p||1 138 | m5-4 2 l||0 139 | m5-4 3 l||0 140 | m5-4 4 c-n||1 141 | 12-10 1 c-i||0 142 | 8-557 1 l||0 143 | m1-22 1 l||0 144 | m1-22 2 l||0 145 | m10-20 1 l||1 146 | m10-20 2 c-e|l||0 147 | m10-20 3 c-e|l||1 148 | 3-89 1 m||0 149 | m12-307 1 c-p||0 150 | m12-307 2 l||0 151 | m12-307 3 c-i||1 152 | -------------------------------------------------------------------------------- /bert/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /bert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import argparse 23 | import tensorflow as tf 24 | import torch 25 | import numpy as np 26 | 27 | from modeling import BertConfig, BertModel 28 | 29 | parser = argparse.ArgumentParser() 30 | 31 | ## Required parameters 32 | parser.add_argument("--tf_checkpoint_path", 33 | default = None, 34 | type = str, 35 | required = True, 36 | help = "Path the TensorFlow checkpoint path.") 37 | parser.add_argument("--bert_config_file", 38 | default = None, 39 | type = str, 40 | required = True, 41 | help = "The config json file corresponding to the pre-trained BERT model. 
\n" 42 | "This specifies the model architecture.") 43 | parser.add_argument("--pytorch_dump_path", 44 | default = None, 45 | type = str, 46 | required = True, 47 | help = "Path to the output PyTorch model.") 48 | 49 | args = parser.parse_args() 50 | 51 | def convert(): 52 | # Initialise PyTorch model 53 | config = BertConfig.from_json_file(args.bert_config_file) 54 | model = BertModel(config) 55 | 56 | # Load weights from TF model 57 | path = args.tf_checkpoint_path 58 | print("Converting TensorFlow checkpoint from {}".format(path)) 59 | 60 | init_vars = tf.train.list_variables(path) 61 | names = [] 62 | arrays = [] 63 | for name, shape in init_vars: 64 | print("Loading {} with shape {}".format(name, shape)) 65 | array = tf.train.load_variable(path, name) 66 | print("Numpy array shape {}".format(array.shape)) 67 | names.append(name) 68 | arrays.append(array) 69 | 70 | for name, array in zip(names, arrays): 71 | name = name[5:] # skip "bert/" 72 | print("Loading {}".format(name)) 73 | name = name.split('/') 74 | if name[0] in ['redictions', 'eq_relationship']: 75 | print("Skipping") 76 | continue 77 | pointer = model 78 | for m_name in name: 79 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 80 | l = re.split(r'_(\d+)', m_name) 81 | else: 82 | l = [m_name] 83 | if l[0] == 'kernel': 84 | pointer = getattr(pointer, 'weight') 85 | else: 86 | pointer = getattr(pointer, l[0]) 87 | if len(l) >= 2: 88 | num = int(l[1]) 89 | pointer = pointer[num] 90 | if m_name[-11:] == '_embeddings': 91 | pointer = getattr(pointer, 'weight') 92 | elif m_name == 'kernel': 93 | array = np.transpose(array) 94 | try: 95 | assert pointer.shape == array.shape 96 | except AssertionError as e: 97 | e.args += (pointer.shape, array.shape) 98 | raise 99 | pointer.data = torch.from_numpy(array) 100 | 101 | # Save pytorch-model 102 | torch.save(model.state_dict(), args.pytorch_dump_path) 103 | 104 | if __name__ == "__main__": 105 | convert() 106 | -------------------------------------------------------------------------------- /bert/extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Extract pre-computed feature vectors from a PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import codecs 23 | import collections 24 | import logging 25 | import json 26 | import re 27 | 28 | import torch 29 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 30 | from torch.utils.data.distributed import DistributedSampler 31 | 32 | import tokenization 33 | from modeling import BertConfig, BertModel 34 | 35 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 36 | datefmt = '%m/%d/%Y %H:%M:%S', 37 | level = logging.INFO) 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | class InputExample(object): 42 | 43 | def __init__(self, unique_id, text_a, text_b): 44 | self.unique_id = unique_id 45 | self.text_a = text_a 46 | self.text_b = text_b 47 | 48 | 49 | class InputFeatures(object): 50 | """A single set of features of data.""" 51 | 52 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 53 | self.unique_id = unique_id 54 | self.tokens = tokens 55 | self.input_ids = input_ids 56 | self.input_mask = input_mask 57 | self.input_type_ids = input_type_ids 58 | 59 | 60 | def convert_examples_to_features(examples, seq_length, tokenizer): 61 | """Loads a data file into a list of `InputBatch`s.""" 62 | 63 | features = [] 64 | for (ex_index, example) in enumerate(examples): 65 | tokens_a = tokenizer.tokenize(example.text_a) 66 | 67 | tokens_b = None 68 | if example.text_b: 69 | tokens_b = tokenizer.tokenize(example.text_b) 70 | 71 | if tokens_b: 72 | # Modifies `tokens_a` and `tokens_b` in place so that the total 73 | # length is less than the specified length. 74 | # Account for [CLS], [SEP], [SEP] with "- 3" 75 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 76 | else: 77 | # Account for [CLS] and [SEP] with "- 2" 78 | if len(tokens_a) > seq_length - 2: 79 | tokens_a = tokens_a[0:(seq_length - 2)] 80 | 81 | # The convention in BERT is: 82 | # (a) For sequence pairs: 83 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 84 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 85 | # (b) For single sequences: 86 | # tokens: [CLS] the dog is hairy . [SEP] 87 | # type_ids: 0 0 0 0 0 0 0 88 | # 89 | # Where "type_ids" are used to indicate whether this is the first 90 | # sequence or the second sequence. The embedding vectors for `type=0` and 91 | # `type=1` were learned during pre-training and are added to the wordpiece 92 | # embedding vector (and position vector). This is not *strictly* necessary 93 | # since the [SEP] token unambigiously separates the sequences, but it makes 94 | # it easier for the model to learn the concept of sequences. 95 | # 96 | # For classification tasks, the first vector (corresponding to [CLS]) is 97 | # used as as the "sentence vector". Note that this only makes sense because 98 | # the entire model is fine-tuned. 
99 | tokens = [] 100 | input_type_ids = [] 101 | tokens.append("[CLS]") 102 | input_type_ids.append(0) 103 | for token in tokens_a: 104 | tokens.append(token) 105 | input_type_ids.append(0) 106 | tokens.append("[SEP]") 107 | input_type_ids.append(0) 108 | 109 | if tokens_b: 110 | for token in tokens_b: 111 | tokens.append(token) 112 | input_type_ids.append(1) 113 | tokens.append("[SEP]") 114 | input_type_ids.append(1) 115 | 116 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 117 | 118 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 119 | # tokens are attended to. 120 | input_mask = [1] * len(input_ids) 121 | 122 | # Zero-pad up to the sequence length. 123 | while len(input_ids) < seq_length: 124 | input_ids.append(0) 125 | input_mask.append(0) 126 | input_type_ids.append(0) 127 | 128 | assert len(input_ids) == seq_length 129 | assert len(input_mask) == seq_length 130 | assert len(input_type_ids) == seq_length 131 | 132 | if ex_index < 5: 133 | logger.info("*** Example ***") 134 | logger.info("unique_id: %s" % (example.unique_id)) 135 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens])) 136 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 137 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 138 | logger.info( 139 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 140 | 141 | features.append( 142 | InputFeatures( 143 | unique_id=example.unique_id, 144 | tokens=tokens, 145 | input_ids=input_ids, 146 | input_mask=input_mask, 147 | input_type_ids=input_type_ids)) 148 | return features 149 | 150 | 151 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 152 | """Truncates a sequence pair in place to the maximum length.""" 153 | 154 | # This is a simple heuristic which will always truncate the longer sequence 155 | # one token at a time. This makes more sense than truncating an equal percent 156 | # of tokens from each, since if one sequence is very short then each token 157 | # that's truncated likely contains more information than a longer sequence. 158 | while True: 159 | total_length = len(tokens_a) + len(tokens_b) 160 | if total_length <= max_length: 161 | break 162 | if len(tokens_a) > len(tokens_b): 163 | tokens_a.pop() 164 | else: 165 | tokens_b.pop() 166 | 167 | 168 | def read_examples(input_file): 169 | """Read a list of `InputExample`s from an input file.""" 170 | examples = [] 171 | unique_id = 0 172 | with open(input_file, "r") as reader: 173 | while True: 174 | line = tokenization.convert_to_unicode(reader.readline()) 175 | if not line: 176 | break 177 | line = line.strip() 178 | text_a = None 179 | text_b = None 180 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 181 | if m is None: 182 | text_a = line 183 | else: 184 | text_a = m.group(1) 185 | text_b = m.group(2) 186 | examples.append( 187 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 188 | unique_id += 1 189 | return examples 190 | 191 | 192 | def main(): 193 | parser = argparse.ArgumentParser() 194 | 195 | ## Required parameters 196 | parser.add_argument("--input_file", default=None, type=str, required=True) 197 | parser.add_argument("--vocab_file", default=None, type=str, required=True, 198 | help="The vocabulary file that the BERT model was trained on.") 199 | parser.add_argument("--output_file", default=None, type=str, required=True) 200 | parser.add_argument("--bert_config_file", default=None, type=str, required=True, 201 | help="The config json file corresponding to the pre-trained BERT model. 
" 202 | "This specifies the model architecture.") 203 | parser.add_argument("--init_checkpoint", default=None, type=str, required=True, 204 | help="Initial checkpoint (usually from a pre-trained BERT model).") 205 | 206 | ## Other parameters 207 | parser.add_argument("--layers", default="-1,-2,-3,-4", type=str) 208 | parser.add_argument("--max_seq_length", default=128, type=int, 209 | help="The maximum total input sequence length after WordPiece tokenization. Sequences longer " 210 | "than this will be truncated, and sequences shorter than this will be padded.") 211 | parser.add_argument("--do_lower_case", default=True, action='store_true', 212 | help="Whether to lower case the input text. Should be True for uncased " 213 | "models and False for cased models.") 214 | parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.") 215 | parser.add_argument("--local_rank", 216 | type=int, 217 | default=-1, 218 | help = "local_rank for distributed training on gpus") 219 | 220 | args = parser.parse_args() 221 | 222 | if args.local_rank == -1 or args.no_cuda: 223 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 224 | n_gpu = torch.cuda.device_count() 225 | else: 226 | device = torch.device("cuda", args.local_rank) 227 | n_gpu = 1 228 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 229 | torch.distributed.init_process_group(backend='nccl') 230 | logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1)) 231 | 232 | layer_indexes = [int(x) for x in args.layers.split(",")] 233 | 234 | bert_config = BertConfig.from_json_file(args.bert_config_file) 235 | 236 | tokenizer = tokenization.FullTokenizer( 237 | vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) 238 | 239 | examples = read_examples(args.input_file) 240 | 241 | features = convert_examples_to_features( 242 | examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer) 243 | 244 | unique_id_to_feature = {} 245 | for feature in features: 246 | unique_id_to_feature[feature.unique_id] = feature 247 | 248 | model = BertModel(bert_config) 249 | if args.init_checkpoint is not None: 250 | model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) 251 | model.to(device) 252 | 253 | if args.local_rank != -1: 254 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 255 | output_device=args.local_rank) 256 | elif n_gpu > 1: 257 | model = torch.nn.DataParallel(model) 258 | 259 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 260 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 261 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 262 | 263 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index) 264 | if args.local_rank == -1: 265 | eval_sampler = SequentialSampler(eval_data) 266 | else: 267 | eval_sampler = DistributedSampler(eval_data) 268 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size) 269 | 270 | model.eval() 271 | with open(args.output_file, "w", encoding='utf-8') as writer: 272 | for input_ids, input_mask, example_indices in eval_dataloader: 273 | input_ids = input_ids.to(device) 274 | input_mask = input_mask.to(device) 275 | 276 | all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask) 277 | all_encoder_layers = all_encoder_layers 278 
| 279 | for b, example_index in enumerate(example_indices): 280 | feature = features[example_index.item()] 281 | unique_id = int(feature.unique_id) 282 | # feature = unique_id_to_feature[unique_id] 283 | output_json = collections.OrderedDict() 284 | output_json["linex_index"] = unique_id 285 | all_out_features = [] 286 | for (i, token) in enumerate(feature.tokens): 287 | all_layers = [] 288 | for (j, layer_index) in enumerate(layer_indexes): 289 | layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy() 290 | layer_output = layer_output[b] 291 | layers = collections.OrderedDict() 292 | layers["index"] = layer_index 293 | layers["values"] = [ 294 | round(x.item(), 6) for x in layer_output[i] 295 | ] 296 | all_layers.append(layers) 297 | out_features = collections.OrderedDict() 298 | out_features["token"] = token 299 | out_features["layers"] = all_layers 300 | all_out_features.append(out_features) 301 | output_json["features"] = all_out_features 302 | writer.write(json.dumps(output_json) + "\n") 303 | 304 | 305 | if __name__ == "__main__": 306 | main() 307 | -------------------------------------------------------------------------------- /bert/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import copy 22 | import json 23 | import math 24 | import six 25 | import torch 26 | import torch.nn as nn 27 | from torch.nn import CrossEntropyLoss 28 | 29 | def gelu(x): 30 | """Implementation of the gelu activation function. 31 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 32 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 33 | """ 34 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 35 | 36 | 37 | class BertConfig(object): 38 | """Configuration class to store the configuration of a `BertModel`. 39 | """ 40 | def __init__(self, 41 | vocab_size, 42 | hidden_size=768, 43 | num_hidden_layers=12, 44 | num_attention_heads=12, 45 | intermediate_size=3072, 46 | hidden_act="gelu", 47 | hidden_dropout_prob=0.1, 48 | attention_probs_dropout_prob=0.1, 49 | max_position_embeddings=512, 50 | type_vocab_size=16, 51 | initializer_range=0.02): 52 | """Constructs BertConfig. 53 | 54 | Args: 55 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 56 | hidden_size: Size of the encoder layers and the pooler layer. 57 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 58 | num_attention_heads: Number of attention heads for each attention layer in 59 | the Transformer encoder. 
60 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 61 | layer in the Transformer encoder. 62 | hidden_act: The non-linear activation function (function or string) in the 63 | encoder and pooler. 64 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 65 | layers in the embeddings, encoder, and pooler. 66 | attention_probs_dropout_prob: The dropout ratio for the attention 67 | probabilities. 68 | max_position_embeddings: The maximum sequence length that this model might 69 | ever be used with. Typically set this to something large just in case 70 | (e.g., 512 or 1024 or 2048). 71 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 72 | `BertModel`. 73 | initializer_range: The sttdev of the truncated_normal_initializer for 74 | initializing all weight matrices. 75 | """ 76 | self.vocab_size = vocab_size 77 | self.hidden_size = hidden_size 78 | self.num_hidden_layers = num_hidden_layers 79 | self.num_attention_heads = num_attention_heads 80 | self.hidden_act = hidden_act 81 | self.intermediate_size = intermediate_size 82 | self.hidden_dropout_prob = hidden_dropout_prob 83 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 84 | self.max_position_embeddings = max_position_embeddings 85 | self.type_vocab_size = type_vocab_size 86 | self.initializer_range = initializer_range 87 | 88 | @classmethod 89 | def from_dict(cls, json_object): 90 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 91 | config = BertConfig(vocab_size=None) 92 | for (key, value) in six.iteritems(json_object): 93 | config.__dict__[key] = value 94 | return config 95 | 96 | @classmethod 97 | def from_json_file(cls, json_file): 98 | """Constructs a `BertConfig` from a json file of parameters.""" 99 | with open(json_file, "r") as reader: 100 | text = reader.read() 101 | return cls.from_dict(json.loads(text)) 102 | 103 | def to_dict(self): 104 | """Serializes this instance to a Python dictionary.""" 105 | output = copy.deepcopy(self.__dict__) 106 | return output 107 | 108 | def to_json_string(self): 109 | """Serializes this instance to a JSON string.""" 110 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 111 | 112 | 113 | class BERTLayerNorm(nn.Module): 114 | def __init__(self, config, variance_epsilon=1e-12): 115 | """Construct a layernorm module in the TF style (epsilon inside the square root). 116 | """ 117 | super(BERTLayerNorm, self).__init__() 118 | self.gamma = nn.Parameter(torch.ones(config.hidden_size)) 119 | self.beta = nn.Parameter(torch.zeros(config.hidden_size)) 120 | self.variance_epsilon = variance_epsilon 121 | 122 | def forward(self, x): 123 | u = x.mean(-1, keepdim=True) 124 | s = (x - u).pow(2).mean(-1, keepdim=True) 125 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 126 | return self.gamma * x + self.beta 127 | 128 | class BERTEmbeddings(nn.Module): 129 | def __init__(self, config): 130 | super(BERTEmbeddings, self).__init__() 131 | """Construct the embedding module from word, position and token_type embeddings. 
132 | """ 133 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 134 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 135 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 136 | 137 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 138 | # any TensorFlow checkpoint file 139 | self.LayerNorm = BERTLayerNorm(config) 140 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 141 | 142 | def forward(self, input_ids, token_type_ids=None): 143 | seq_length = input_ids.size(1) 144 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 145 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 146 | if token_type_ids is None: 147 | token_type_ids = torch.zeros_like(input_ids) 148 | 149 | words_embeddings = self.word_embeddings(input_ids) 150 | position_embeddings = self.position_embeddings(position_ids) 151 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 152 | 153 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 154 | embeddings = self.LayerNorm(embeddings) 155 | embeddings = self.dropout(embeddings) 156 | return embeddings 157 | 158 | 159 | class BERTSelfAttention(nn.Module): 160 | def __init__(self, config): 161 | super(BERTSelfAttention, self).__init__() 162 | if config.hidden_size % config.num_attention_heads != 0: 163 | raise ValueError( 164 | "The hidden size (%d) is not a multiple of the number of attention " 165 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 166 | self.num_attention_heads = config.num_attention_heads 167 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 168 | self.all_head_size = self.num_attention_heads * self.attention_head_size 169 | 170 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 171 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 172 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 173 | 174 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 175 | 176 | def transpose_for_scores(self, x): 177 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 178 | x = x.view(*new_x_shape) 179 | return x.permute(0, 2, 1, 3) 180 | 181 | def forward(self, hidden_states, attention_mask): 182 | mixed_query_layer = self.query(hidden_states) 183 | mixed_key_layer = self.key(hidden_states) 184 | mixed_value_layer = self.value(hidden_states) 185 | 186 | query_layer = self.transpose_for_scores(mixed_query_layer) 187 | key_layer = self.transpose_for_scores(mixed_key_layer) 188 | value_layer = self.transpose_for_scores(mixed_value_layer) 189 | 190 | # Take the dot product between "query" and "key" to get the raw attention scores. 191 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 192 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 193 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 194 | attention_scores = attention_scores + attention_mask 195 | 196 | # Normalize the attention scores to probabilities. 197 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 198 | 199 | # This is actually dropping out entire tokens to attend to, which might 200 | # seem a bit unusual, but is taken from the original Transformer paper. 
201 | attention_probs = self.dropout(attention_probs) 202 | 203 | context_layer = torch.matmul(attention_probs, value_layer) 204 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 205 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 206 | context_layer = context_layer.view(*new_context_layer_shape) 207 | return context_layer 208 | 209 | 210 | class BERTSelfOutput(nn.Module): 211 | def __init__(self, config): 212 | super(BERTSelfOutput, self).__init__() 213 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 214 | self.LayerNorm = BERTLayerNorm(config) 215 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 216 | 217 | def forward(self, hidden_states, input_tensor): 218 | hidden_states = self.dense(hidden_states) 219 | hidden_states = self.dropout(hidden_states) 220 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 221 | return hidden_states 222 | 223 | 224 | class BERTAttention(nn.Module): 225 | def __init__(self, config): 226 | super(BERTAttention, self).__init__() 227 | self.self = BERTSelfAttention(config) 228 | self.output = BERTSelfOutput(config) 229 | 230 | def forward(self, input_tensor, attention_mask): 231 | self_output = self.self(input_tensor, attention_mask) 232 | attention_output = self.output(self_output, input_tensor) 233 | return attention_output 234 | 235 | 236 | class BERTIntermediate(nn.Module): 237 | def __init__(self, config): 238 | super(BERTIntermediate, self).__init__() 239 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 240 | self.intermediate_act_fn = gelu 241 | 242 | def forward(self, hidden_states): 243 | hidden_states = self.dense(hidden_states) 244 | hidden_states = self.intermediate_act_fn(hidden_states) 245 | return hidden_states 246 | 247 | 248 | class BERTOutput(nn.Module): 249 | def __init__(self, config): 250 | super(BERTOutput, self).__init__() 251 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 252 | self.LayerNorm = BERTLayerNorm(config) 253 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 254 | 255 | def forward(self, hidden_states, input_tensor): 256 | hidden_states = self.dense(hidden_states) 257 | hidden_states = self.dropout(hidden_states) 258 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 259 | return hidden_states 260 | 261 | 262 | class BERTLayer(nn.Module): 263 | def __init__(self, config): 264 | super(BERTLayer, self).__init__() 265 | self.attention = BERTAttention(config) 266 | self.intermediate = BERTIntermediate(config) 267 | self.output = BERTOutput(config) 268 | 269 | def forward(self, hidden_states, attention_mask): 270 | attention_output = self.attention(hidden_states, attention_mask) 271 | intermediate_output = self.intermediate(attention_output) 272 | layer_output = self.output(intermediate_output, attention_output) 273 | return layer_output 274 | 275 | 276 | class BERTEncoder(nn.Module): 277 | def __init__(self, config): 278 | super(BERTEncoder, self).__init__() 279 | layer = BERTLayer(config) 280 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 281 | 282 | def forward(self, hidden_states, attention_mask): 283 | all_encoder_layers = [] 284 | for layer_module in self.layer: 285 | hidden_states = layer_module(hidden_states, attention_mask) 286 | all_encoder_layers.append(hidden_states) 287 | return all_encoder_layers 288 | 289 | 290 | class BERTPooler(nn.Module): 291 | def __init__(self, config): 292 | super(BERTPooler, self).__init__() 293 | 
self.dense = nn.Linear(config.hidden_size, config.hidden_size) 294 | self.activation = nn.Tanh() 295 | 296 | def forward(self, hidden_states): 297 | # We "pool" the model by simply taking the hidden state corresponding 298 | # to the first token. 299 | first_token_tensor = hidden_states[:, 0] 300 | pooled_output = self.dense(first_token_tensor) 301 | pooled_output = self.activation(pooled_output) 302 | return pooled_output 303 | 304 | 305 | class BertModel(nn.Module): 306 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 307 | 308 | Example usage: 309 | ```python 310 | # Already been converted into WordPiece token ids 311 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 312 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 313 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 314 | 315 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 316 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 317 | 318 | model = modeling.BertModel(config=config) 319 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 320 | ``` 321 | """ 322 | def __init__(self, config: BertConfig): 323 | """Constructor for BertModel. 324 | 325 | Args: 326 | config: `BertConfig` instance. 327 | """ 328 | super(BertModel, self).__init__() 329 | self.embeddings = BERTEmbeddings(config) 330 | self.encoder = BERTEncoder(config) 331 | self.pooler = BERTPooler(config) 332 | 333 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 334 | if attention_mask is None: 335 | attention_mask = torch.ones_like(input_ids) 336 | if token_type_ids is None: 337 | token_type_ids = torch.zeros_like(input_ids) 338 | 339 | # We create a 3D attention mask from a 2D tensor mask. 340 | # Sizes are [batch_size, 1, 1, to_seq_length] 341 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 342 | # this attention mask is more simple than the triangular masking of causal attention 343 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 344 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 345 | 346 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 347 | # masked positions, this operation will create a tensor which is 0.0 for 348 | # positions we want to attend and -10000.0 for masked positions. 349 | # Since we are adding it to the raw scores before the softmax, this is 350 | # effectively the same as removing these entirely. 351 | extended_attention_mask = extended_attention_mask.float() 352 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 353 | 354 | embedding_output = self.embeddings(input_ids, token_type_ids) 355 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask) 356 | sequence_output = all_encoder_layers[-1] 357 | pooled_output = self.pooler(sequence_output) 358 | return all_encoder_layers, pooled_output 359 | 360 | class BertForSequenceClassification(nn.Module): 361 | """BERT model for classification. 362 | This module is composed of the BERT model with a linear layer on top of 363 | the pooled output. 
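    For multiple-choice reading comprehension (as in run_classifier.py, where the model is
    built with a single output unit per answer option), forward() flattens inputs of shape
    [batch, n_class, seq_length] so that every option is scored independently, then regroups
    the per-option scores into n_class logits per question before the cross-entropy loss.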
364 | 365 | Example usage: 366 | ```python 367 | # Already been converted into WordPiece token ids 368 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 369 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 370 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 371 | 372 | config = BertConfig(vocab_size=32000, hidden_size=512, 373 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 374 | 375 | num_labels = 2 376 | 377 | model = BertForSequenceClassification(config, num_labels) 378 | logits = model(input_ids, token_type_ids, input_mask) 379 | ``` 380 | """ 381 | def __init__(self, config, num_labels): 382 | super(BertForSequenceClassification, self).__init__() 383 | self.bert = BertModel(config) 384 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 385 | self.classifier = nn.Linear(config.hidden_size, num_labels) 386 | 387 | def init_weights(module): 388 | if isinstance(module, (nn.Linear, nn.Embedding)): 389 | # Slightly different from the TF version which uses truncated_normal for initialization 390 | # cf https://github.com/pytorch/pytorch/pull/5617 391 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 392 | elif isinstance(module, BERTLayerNorm): 393 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 394 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 395 | if isinstance(module, nn.Linear): 396 | module.bias.data.zero_() 397 | self.apply(init_weights) 398 | 399 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None, n_class=1): 400 | seq_length = input_ids.size(2) 401 | _, pooled_output = self.bert(input_ids.view(-1,seq_length), 402 | token_type_ids.view(-1,seq_length), 403 | attention_mask.view(-1,seq_length)) 404 | pooled_output = self.dropout(pooled_output) 405 | logits = self.classifier(pooled_output) 406 | logits = logits.view(-1, n_class) 407 | 408 | if labels is not None: 409 | loss_fct = CrossEntropyLoss() 410 | labels = labels.view(-1) 411 | loss = loss_fct(logits, labels) 412 | return loss, logits 413 | else: 414 | return logits 415 | 416 | class BertForQuestionAnswering(nn.Module): #FIXME (kai): not modified accordingly yet 417 | """BERT model for Question Answering (span extraction). 
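    (This head is retained from the upstream BERT implementation and is not used by run_classifier.py for C3.)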
418 | This module is composed of the BERT model with a linear layer on top of 419 | the sequence output that computes start_logits and end_logits 420 | 421 | Example usage: 422 | ```python 423 | # Already been converted into WordPiece token ids 424 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 425 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 426 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 427 | 428 | config = BertConfig(vocab_size=32000, hidden_size=512, 429 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 430 | 431 | model = BertForQuestionAnswering(config) 432 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 433 | ``` 434 | """ 435 | def __init__(self, config): 436 | super(BertForQuestionAnswering, self).__init__() 437 | self.bert = BertModel(config) 438 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 439 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 440 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 441 | 442 | def init_weights(module): 443 | if isinstance(module, (nn.Linear, nn.Embedding)): 444 | # Slightly different from the TF version which uses truncated_normal for initialization 445 | # cf https://github.com/pytorch/pytorch/pull/5617 446 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 447 | elif isinstance(module, BERTLayerNorm): 448 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 449 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 450 | if isinstance(module, nn.Linear): 451 | module.bias.data.zero_() 452 | self.apply(init_weights) 453 | 454 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None): 455 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask) 456 | sequence_output = all_encoder_layers[-1] 457 | logits = self.qa_outputs(sequence_output) 458 | start_logits, end_logits = logits.split(1, dim=-1) 459 | start_logits = start_logits.squeeze(-1) 460 | end_logits = end_logits.squeeze(-1) 461 | 462 | if start_positions is not None and end_positions is not None: 463 | # If we are on multi-GPU, split add a dimension 464 | if len(start_positions.size()) > 1: 465 | start_positions = start_positions.squeeze(-1) 466 | if len(end_positions.size()) > 1: 467 | end_positions = end_positions.squeeze(-1) 468 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 469 | ignored_index = start_logits.size(1) 470 | start_positions.clamp_(0, ignored_index) 471 | end_positions.clamp_(0, ignored_index) 472 | 473 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 474 | start_loss = loss_fct(start_logits, start_positions) 475 | end_loss = loss_fct(end_logits, end_positions) 476 | total_loss = (start_loss + end_loss) / 2 477 | return total_loss 478 | else: 479 | return start_logits, end_logits 480 | -------------------------------------------------------------------------------- /bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.nn.utils import clip_grad_norm_ 21 | 22 | def warmup_cosine(x, warmup=0.002): 23 | if x < warmup: 24 | return x/warmup 25 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 26 | 27 | def warmup_constant(x, warmup=0.002): 28 | if x < warmup: 29 | return x/warmup 30 | return 1.0 31 | 32 | def warmup_linear(x, warmup=0.002): 33 | if x < warmup: 34 | return x/warmup 35 | return 1.0 - x 36 | 37 | SCHEDULES = { 38 | 'warmup_cosine':warmup_cosine, 39 | 'warmup_constant':warmup_constant, 40 | 'warmup_linear':warmup_linear, 41 | } 42 | 43 | 44 | class BERTAdam(Optimizer): 45 | """Implements BERT version of Adam algorithm with weight decay fix (and no ). 46 | Params: 47 | lr: learning rate 48 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 49 | t_total: total number of training steps for the learning 50 | rate schedule, -1 means constant learning rate. Default: -1 51 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 52 | b1: Adams b1. Default: 0.9 53 | b2: Adams b2. Default: 0.999 54 | e: Adams epsilon. Default: 1e-6 55 | weight_decay_rate: Weight decay. Default: 0.01 56 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 57 | """ 58 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 59 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, 60 | max_grad_norm=1.0): 61 | if not lr >= 0.0: 62 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 63 | if schedule not in SCHEDULES: 64 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 65 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 66 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 67 | if not 0.0 <= b1 < 1.0: 68 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 69 | if not 0.0 <= b2 < 1.0: 70 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 71 | if not e >= 0.0: 72 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 73 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 74 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, 75 | max_grad_norm=max_grad_norm) 76 | super(BERTAdam, self).__init__(params, defaults) 77 | 78 | def get_lr(self): 79 | lr = [] 80 | for group in self.param_groups: 81 | for p in group['params']: 82 | state = self.state[p] 83 | if len(state) == 0: 84 | return [0] 85 | if group['t_total'] != -1: 86 | schedule_fct = SCHEDULES[group['schedule']] 87 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 88 | else: 89 | lr_scheduled = group['lr'] 90 | lr.append(lr_scheduled) 91 | return lr 92 | 93 | def step(self, closure=None): 94 | """Performs a single optimization step. 95 | 96 | Arguments: 97 | closure (callable, optional): A closure that reevaluates the model 98 | and returns the loss. 
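        Returns:
            The loss returned by `closure` if one was provided, otherwise None.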
99 | """ 100 | loss = None 101 | if closure is not None: 102 | loss = closure() 103 | 104 | for group in self.param_groups: 105 | for p in group['params']: 106 | if p.grad is None: 107 | continue 108 | grad = p.grad.data 109 | if grad.is_sparse: 110 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 111 | 112 | state = self.state[p] 113 | 114 | # State initialization 115 | if len(state) == 0: 116 | state['step'] = 0 117 | # Exponential moving average of gradient values 118 | state['next_m'] = torch.zeros_like(p.data) 119 | # Exponential moving average of squared gradient values 120 | state['next_v'] = torch.zeros_like(p.data) 121 | 122 | next_m, next_v = state['next_m'], state['next_v'] 123 | beta1, beta2 = group['b1'], group['b2'] 124 | 125 | # Add grad clipping 126 | if group['max_grad_norm'] > 0: 127 | clip_grad_norm_(p, group['max_grad_norm']) 128 | 129 | # Decay the first and second moment running average coefficient 130 | # In-place operations to update the averages at the same time 131 | next_m.mul_(beta1).add_(1 - beta1, grad) 132 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 133 | update = next_m / (next_v.sqrt() + group['e']) 134 | 135 | # Just adding the square of the weights to the loss function is *not* 136 | # the correct way of using L2 regularization/weight decay with Adam, 137 | # since that will interact with the m and v parameters in strange ways. 138 | # 139 | # Instead we want ot decay the weights in a manner that doesn't interact 140 | # with the m/v parameters. This is equivalent to adding the square 141 | # of the weights to the loss with plain (non-momentum) SGD. 142 | if group['weight_decay_rate'] > 0.0: 143 | update += group['weight_decay_rate'] * p.data 144 | 145 | if group['t_total'] != -1: 146 | schedule_fct = SCHEDULES[group['schedule']] 147 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 148 | else: 149 | lr_scheduled = group['lr'] 150 | 151 | update_with_lr = lr_scheduled * update 152 | p.data.add_(-update_with_lr) 153 | 154 | state['step'] += 1 155 | 156 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 157 | # bias_correction1 = 1 - beta1 ** state['step'] 158 | # bias_correction2 = 1 - beta2 ** state['step'] 159 | 160 | return loss 161 | -------------------------------------------------------------------------------- /bert/run_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import csv 22 | import os 23 | import logging 24 | import argparse 25 | import random 26 | from tqdm import tqdm, trange 27 | 28 | import numpy as np 29 | import torch 30 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 31 | from torch.utils.data.distributed import DistributedSampler 32 | 33 | import tokenization 34 | from modeling import BertConfig, BertForSequenceClassification 35 | from optimization import BERTAdam 36 | 37 | import json 38 | 39 | n_class = 4 40 | reverse_order = False 41 | sa_step = False 42 | 43 | 44 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 45 | datefmt = '%m/%d/%Y %H:%M:%S', 46 | level = logging.INFO) 47 | logger = logging.getLogger(__name__) 48 | 49 | 50 | class InputExample(object): 51 | """A single training/test example for simple sequence classification.""" 52 | 53 | def __init__(self, guid, text_a, text_b=None, label=None, text_c=None): 54 | """Constructs a InputExample. 55 | 56 | Args: 57 | guid: Unique id for the example. 58 | text_a: string. The untokenized text of the first sequence. For single 59 | sequence tasks, only this sequence must be specified. 60 | text_b: (Optional) string. The untokenized text of the second sequence. 61 | Only must be specified for sequence pair tasks. 62 | label: (Optional) string. The label of the example. This should be 63 | specified for train and dev examples, but not for test examples. 64 | """ 65 | self.guid = guid 66 | self.text_a = text_a 67 | self.text_b = text_b 68 | self.text_c = text_c 69 | self.label = label 70 | 71 | 72 | class InputFeatures(object): 73 | """A single set of features of data.""" 74 | 75 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 76 | self.input_ids = input_ids 77 | self.input_mask = input_mask 78 | self.segment_ids = segment_ids 79 | self.label_id = label_id 80 | 81 | 82 | class DataProcessor(object): 83 | """Base class for data converters for sequence classification data sets.""" 84 | 85 | def get_train_examples(self, data_dir): 86 | """Gets a collection of `InputExample`s for the train set.""" 87 | raise NotImplementedError() 88 | 89 | def get_dev_examples(self, data_dir): 90 | """Gets a collection of `InputExample`s for the dev set.""" 91 | raise NotImplementedError() 92 | 93 | def get_labels(self): 94 | """Gets the list of labels for this data set.""" 95 | raise NotImplementedError() 96 | 97 | @classmethod 98 | def _read_tsv(cls, input_file, quotechar=None): 99 | """Reads a tab separated value file.""" 100 | with open(input_file, "r") as f: 101 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 102 | lines = [] 103 | for line in reader: 104 | lines.append(line) 105 | return lines 106 | 107 | 108 | class c3Processor(DataProcessor): 109 | def __init__(self): 110 | random.seed(42) 111 | self.D = [[], [], []] 112 | 113 | for sid in range(3): 114 | data = [] 115 | for subtask in ["d", "m"]: 116 | with open("data/c3-"+subtask+"-"+["train.json", "dev.json", "test.json"][sid], "r", encoding="utf8") as f: 117 | data += json.load(f) 118 | if sid == 0: 119 | random.shuffle(data) 120 | for i in range(len(data)): 121 | for j in range(len(data[i][1])): 122 | d = ['\n'.join(data[i][0]).lower(), data[i][1][j]["question"].lower()] 123 | for k in range(len(data[i][1][j]["choice"])): 124 | d += 
[data[i][1][j]["choice"][k].lower()] 125 | for k in range(len(data[i][1][j]["choice"]), 4): 126 | d += [''] 127 | d += [data[i][1][j]["answer"].lower()] 128 | self.D[sid] += [d] 129 | 130 | def get_train_examples(self, data_dir): 131 | """See base class.""" 132 | return self._create_examples( 133 | self.D[0], "train") 134 | 135 | def get_test_examples(self, data_dir): 136 | """See base class.""" 137 | return self._create_examples( 138 | self.D[2], "test") 139 | 140 | def get_dev_examples(self, data_dir): 141 | """See base class.""" 142 | return self._create_examples( 143 | self.D[1], "dev") 144 | 145 | def get_labels(self): 146 | """See base class.""" 147 | return ["0", "1", "2", "3"] 148 | 149 | def _create_examples(self, data, set_type): 150 | """Creates examples for the training and dev sets.""" 151 | examples = [] 152 | for (i, d) in enumerate(data): 153 | for k in range(4): 154 | if data[i][2+k] == data[i][6]: 155 | answer = str(k) 156 | 157 | label = tokenization.convert_to_unicode(answer) 158 | 159 | for k in range(4): 160 | guid = "%s-%s-%s" % (set_type, i, k) 161 | text_a = tokenization.convert_to_unicode(data[i][0]) 162 | text_b = tokenization.convert_to_unicode(data[i][k+2]) 163 | text_c = tokenization.convert_to_unicode(data[i][1]) 164 | examples.append( 165 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, text_c=text_c)) 166 | 167 | return examples 168 | 169 | 170 | 171 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 172 | """Loads a data file into a list of `InputBatch`s.""" 173 | 174 | print("#examples", len(examples)) 175 | 176 | label_map = {} 177 | for (i, label) in enumerate(label_list): 178 | label_map[label] = i 179 | 180 | features = [[]] 181 | for (ex_index, example) in enumerate(examples): 182 | tokens_a = tokenizer.tokenize(example.text_a) 183 | 184 | tokens_b = tokenizer.tokenize(example.text_b) 185 | 186 | tokens_c = tokenizer.tokenize(example.text_c) 187 | 188 | _truncate_seq_tuple(tokens_a, tokens_b, tokens_c, max_seq_length - 4) 189 | tokens_b = tokens_c + ["[SEP]"] + tokens_b 190 | 191 | tokens = [] 192 | segment_ids = [] 193 | tokens.append("[CLS]") 194 | segment_ids.append(0) 195 | for token in tokens_a: 196 | tokens.append(token) 197 | segment_ids.append(0) 198 | tokens.append("[SEP]") 199 | segment_ids.append(0) 200 | 201 | if tokens_b: 202 | for token in tokens_b: 203 | tokens.append(token) 204 | segment_ids.append(1) 205 | tokens.append("[SEP]") 206 | segment_ids.append(1) 207 | 208 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 209 | 210 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 211 | # tokens are attended to. 212 | input_mask = [1] * len(input_ids) 213 | 214 | # Zero-pad up to the sequence length. 
215 | while len(input_ids) < max_seq_length: 216 | input_ids.append(0) 217 | input_mask.append(0) 218 | segment_ids.append(0) 219 | 220 | assert len(input_ids) == max_seq_length 221 | assert len(input_mask) == max_seq_length 222 | assert len(segment_ids) == max_seq_length 223 | 224 | label_id = label_map[example.label] 225 | if ex_index < 5: 226 | logger.info("*** Example ***") 227 | logger.info("guid: %s" % (example.guid)) 228 | logger.info("tokens: %s" % " ".join( 229 | [tokenization.printable_text(x) for x in tokens])) 230 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 231 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 232 | logger.info( 233 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 234 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 235 | 236 | features[-1].append( 237 | InputFeatures( 238 | input_ids=input_ids, 239 | input_mask=input_mask, 240 | segment_ids=segment_ids, 241 | label_id=label_id)) 242 | if len(features[-1]) == n_class: 243 | features.append([]) 244 | 245 | if len(features[-1]) == 0: 246 | features = features[:-1] 247 | print('#features', len(features)) 248 | return features 249 | 250 | 251 | 252 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 253 | """Truncates a sequence pair in place to the maximum length.""" 254 | 255 | # This is a simple heuristic which will always truncate the longer sequence 256 | # one token at a time. This makes more sense than truncating an equal percent 257 | # of tokens from each, since if one sequence is very short then each token 258 | # that's truncated likely contains more information than a longer sequence. 259 | while True: 260 | total_length = len(tokens_a) + len(tokens_b) 261 | if total_length <= max_length: 262 | break 263 | if len(tokens_a) > len(tokens_b): 264 | tokens_a.pop() 265 | else: 266 | tokens_b.pop() 267 | 268 | 269 | def _truncate_seq_tuple(tokens_a, tokens_b, tokens_c, max_length): 270 | """Truncates a sequence tuple in place to the maximum length.""" 271 | 272 | # This is a simple heuristic which will always truncate the longer sequence 273 | # one token at a time. This makes more sense than truncating an equal percent 274 | # of tokens from each, since if one sequence is very short then each token 275 | # that's truncated likely contains more information than a longer sequence. 276 | while True: 277 | total_length = len(tokens_a) + len(tokens_b) + len(tokens_c) 278 | if total_length <= max_length: 279 | break 280 | if len(tokens_a) >= len(tokens_b) and len(tokens_a) >= len(tokens_c): 281 | tokens_a.pop() 282 | elif len(tokens_b) >= len(tokens_a) and len(tokens_b) >= len(tokens_c): 283 | tokens_b.pop() 284 | else: 285 | tokens_c.pop() 286 | 287 | 288 | def accuracy(out, labels): 289 | outputs = np.argmax(out, axis=1) 290 | return np.sum(outputs==labels) 291 | 292 | def main(): 293 | parser = argparse.ArgumentParser() 294 | 295 | ## Required parameters 296 | parser.add_argument("--data_dir", 297 | default=None, 298 | type=str, 299 | required=True, 300 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 301 | parser.add_argument("--bert_config_file", 302 | default=None, 303 | type=str, 304 | required=True, 305 | help="The config json file corresponding to the pre-trained BERT model. 
\n" 306 | "This specifies the model architecture.") 307 | parser.add_argument("--task_name", 308 | default=None, 309 | type=str, 310 | required=True, 311 | help="The name of the task to train.") 312 | parser.add_argument("--vocab_file", 313 | default=None, 314 | type=str, 315 | required=True, 316 | help="The vocabulary file that the BERT model was trained on.") 317 | parser.add_argument("--output_dir", 318 | default=None, 319 | type=str, 320 | required=True, 321 | help="The output directory where the model checkpoints will be written.") 322 | 323 | ## Other parameters 324 | parser.add_argument("--init_checkpoint", 325 | default=None, 326 | type=str, 327 | help="Initial checkpoint (usually from a pre-trained BERT model).") 328 | parser.add_argument("--do_lower_case", 329 | default=False, 330 | action='store_true', 331 | help="Whether to lower case the input text. True for uncased models, False for cased models.") 332 | parser.add_argument("--max_seq_length", 333 | default=128, 334 | type=int, 335 | help="The maximum total input sequence length after WordPiece tokenization. \n" 336 | "Sequences longer than this will be truncated, and sequences shorter \n" 337 | "than this will be padded.") 338 | parser.add_argument("--do_train", 339 | default=False, 340 | action='store_true', 341 | help="Whether to run training.") 342 | parser.add_argument("--do_eval", 343 | default=False, 344 | action='store_true', 345 | help="Whether to run eval on the dev set.") 346 | parser.add_argument("--train_batch_size", 347 | default=32, 348 | type=int, 349 | help="Total batch size for training.") 350 | parser.add_argument("--eval_batch_size", 351 | default=8, 352 | type=int, 353 | help="Total batch size for eval.") 354 | parser.add_argument("--learning_rate", 355 | default=5e-5, 356 | type=float, 357 | help="The initial learning rate for Adam.") 358 | parser.add_argument("--num_train_epochs", 359 | default=3.0, 360 | type=float, 361 | help="Total number of training epochs to perform.") 362 | parser.add_argument("--warmup_proportion", 363 | default=0.1, 364 | type=float, 365 | help="Proportion of training to perform linear learning rate warmup for. 
" 366 | "E.g., 0.1 = 10%% of training.") 367 | parser.add_argument("--save_checkpoints_steps", 368 | default=1000, 369 | type=int, 370 | help="How often to save the model checkpoint.") 371 | parser.add_argument("--no_cuda", 372 | default=False, 373 | action='store_true', 374 | help="Whether not to use CUDA when available") 375 | parser.add_argument("--local_rank", 376 | type=int, 377 | default=-1, 378 | help="local_rank for distributed training on gpus") 379 | parser.add_argument('--seed', 380 | type=int, 381 | default=42, 382 | help="random seed for initialization") 383 | parser.add_argument('--gradient_accumulation_steps', 384 | type=int, 385 | default=1, 386 | help="Number of updates steps to accumualte before performing a backward/update pass.") 387 | args = parser.parse_args() 388 | 389 | processors = { 390 | "c3": c3Processor, 391 | } 392 | 393 | if args.local_rank == -1 or args.no_cuda: 394 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 395 | n_gpu = torch.cuda.device_count() 396 | else: 397 | device = torch.device("cuda", args.local_rank) 398 | n_gpu = 1 399 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 400 | torch.distributed.init_process_group(backend='nccl') 401 | logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1)) 402 | 403 | if args.gradient_accumulation_steps < 1: 404 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 405 | args.gradient_accumulation_steps)) 406 | 407 | args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) 408 | 409 | random.seed(args.seed) 410 | np.random.seed(args.seed) 411 | torch.manual_seed(args.seed) 412 | if n_gpu > 0: 413 | torch.cuda.manual_seed_all(args.seed) 414 | 415 | if not args.do_train and not args.do_eval: 416 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 417 | 418 | bert_config = BertConfig.from_json_file(args.bert_config_file) 419 | 420 | if args.max_seq_length > bert_config.max_position_embeddings: 421 | raise ValueError( 422 | "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format( 423 | args.max_seq_length, bert_config.max_position_embeddings)) 424 | 425 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 426 | if args.do_train: 427 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 428 | else: 429 | os.makedirs(args.output_dir, exist_ok=True) 430 | 431 | task_name = args.task_name.lower() 432 | 433 | if task_name not in processors: 434 | raise ValueError("Task not found: %s" % (task_name)) 435 | 436 | processor = processors[task_name]() 437 | label_list = processor.get_labels() 438 | 439 | tokenizer = tokenization.FullTokenizer( 440 | vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) 441 | 442 | train_examples = None 443 | num_train_steps = None 444 | if args.do_train: 445 | train_examples = processor.get_train_examples(args.data_dir) 446 | num_train_steps = int( 447 | len(train_examples) / n_class / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) 448 | 449 | model = BertForSequenceClassification(bert_config, 1 if n_class > 1 else len(label_list)) 450 | if args.init_checkpoint is not None: 451 | model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) 452 | model.to(device) 453 | 454 | if args.local_rank != -1: 
455 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 456 | output_device=args.local_rank) 457 | elif n_gpu > 1: 458 | model = torch.nn.DataParallel(model) 459 | 460 | no_decay = ['bias', 'gamma', 'beta'] 461 | optimizer_parameters = [ 462 | {'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01}, 463 | {'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0} 464 | ] 465 | 466 | optimizer = BERTAdam(optimizer_parameters, 467 | lr=args.learning_rate, 468 | warmup=args.warmup_proportion, 469 | t_total=num_train_steps) 470 | 471 | global_step = 0 472 | 473 | if args.do_eval: 474 | eval_examples = processor.get_dev_examples(args.data_dir) 475 | eval_features = convert_examples_to_features( 476 | eval_examples, label_list, args.max_seq_length, tokenizer) 477 | 478 | input_ids = [] 479 | input_mask = [] 480 | segment_ids = [] 481 | label_id = [] 482 | 483 | for f in eval_features: 484 | input_ids.append([]) 485 | input_mask.append([]) 486 | segment_ids.append([]) 487 | for i in range(n_class): 488 | input_ids[-1].append(f[i].input_ids) 489 | input_mask[-1].append(f[i].input_mask) 490 | segment_ids[-1].append(f[i].segment_ids) 491 | label_id.append([f[0].label_id]) 492 | 493 | all_input_ids = torch.tensor(input_ids, dtype=torch.long) 494 | all_input_mask = torch.tensor(input_mask, dtype=torch.long) 495 | all_segment_ids = torch.tensor(segment_ids, dtype=torch.long) 496 | all_label_ids = torch.tensor(label_id, dtype=torch.long) 497 | 498 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 499 | if args.local_rank == -1: 500 | eval_sampler = SequentialSampler(eval_data) 501 | else: 502 | eval_sampler = DistributedSampler(eval_data) 503 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 504 | 505 | 506 | if args.do_train: 507 | best_accuracy = 0 508 | 509 | train_features = convert_examples_to_features( 510 | train_examples, label_list, args.max_seq_length, tokenizer) 511 | logger.info("***** Running training *****") 512 | logger.info(" Num examples = %d", len(train_examples)) 513 | logger.info(" Batch size = %d", args.train_batch_size) 514 | logger.info(" Num steps = %d", num_train_steps) 515 | 516 | input_ids = [] 517 | input_mask = [] 518 | segment_ids = [] 519 | label_id = [] 520 | for f in train_features: 521 | input_ids.append([]) 522 | input_mask.append([]) 523 | segment_ids.append([]) 524 | for i in range(n_class): 525 | input_ids[-1].append(f[i].input_ids) 526 | input_mask[-1].append(f[i].input_mask) 527 | segment_ids[-1].append(f[i].segment_ids) 528 | label_id.append([f[0].label_id]) 529 | 530 | all_input_ids = torch.tensor(input_ids, dtype=torch.long) 531 | all_input_mask = torch.tensor(input_mask, dtype=torch.long) 532 | all_segment_ids = torch.tensor(segment_ids, dtype=torch.long) 533 | all_label_ids = torch.tensor(label_id, dtype=torch.long) 534 | 535 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 536 | if args.local_rank == -1: 537 | train_sampler = RandomSampler(train_data) 538 | else: 539 | train_sampler = DistributedSampler(train_data) 540 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 541 | 542 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 543 | model.train() 544 | tr_loss = 0 545 | nb_tr_examples, nb_tr_steps = 0, 0 546 | for step, batch in 
enumerate(tqdm(train_dataloader, desc="Iteration")): 547 | batch = tuple(t.to(device) for t in batch) 548 | input_ids, input_mask, segment_ids, label_ids = batch 549 | loss, _ = model(input_ids, segment_ids, input_mask, label_ids, n_class) 550 | if n_gpu > 1: 551 | loss = loss.mean() # mean() to average on multi-gpu. 552 | if args.gradient_accumulation_steps > 1: 553 | loss = loss / args.gradient_accumulation_steps 554 | loss.backward() 555 | tr_loss += loss.item() 556 | nb_tr_examples += input_ids.size(0) 557 | nb_tr_steps += 1 558 | if (step + 1) % args.gradient_accumulation_steps == 0: 559 | optimizer.step() # We have accumulated enought gradients 560 | model.zero_grad() 561 | global_step += 1 562 | 563 | model.eval() 564 | eval_loss, eval_accuracy = 0, 0 565 | nb_eval_steps, nb_eval_examples = 0, 0 566 | logits_all = [] 567 | for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: 568 | input_ids = input_ids.to(device) 569 | input_mask = input_mask.to(device) 570 | segment_ids = segment_ids.to(device) 571 | label_ids = label_ids.to(device) 572 | 573 | with torch.no_grad(): 574 | tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids, n_class) 575 | 576 | logits = logits.detach().cpu().numpy() 577 | label_ids = label_ids.to('cpu').numpy() 578 | for i in range(len(logits)): 579 | logits_all += [logits[i]] 580 | 581 | tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1)) 582 | 583 | eval_loss += tmp_eval_loss.mean().item() 584 | eval_accuracy += tmp_eval_accuracy 585 | 586 | nb_eval_examples += input_ids.size(0) 587 | nb_eval_steps += 1 588 | 589 | eval_loss = eval_loss / nb_eval_steps 590 | eval_accuracy = eval_accuracy / nb_eval_examples 591 | 592 | if args.do_train: 593 | result = {'eval_loss': eval_loss, 594 | 'eval_accuracy': eval_accuracy, 595 | 'global_step': global_step, 596 | 'loss': tr_loss/nb_tr_steps} 597 | else: 598 | result = {'eval_loss': eval_loss, 599 | 'eval_accuracy': eval_accuracy} 600 | 601 | logger.info("***** Eval results *****") 602 | for key in sorted(result.keys()): 603 | logger.info(" %s = %s", key, str(result[key])) 604 | 605 | if eval_accuracy >= best_accuracy: 606 | torch.save(model.state_dict(), os.path.join(args.output_dir, "model_best.pt")) 607 | best_accuracy = eval_accuracy 608 | 609 | model.load_state_dict(torch.load(os.path.join(args.output_dir, "model_best.pt"))) 610 | torch.save(model.state_dict(), os.path.join(args.output_dir, "model.pt")) 611 | 612 | model.load_state_dict(torch.load(os.path.join(args.output_dir, "model.pt"))) 613 | 614 | if args.do_eval: 615 | logger.info("***** Running evaluation *****") 616 | logger.info(" Num examples = %d", len(eval_examples)) 617 | logger.info(" Batch size = %d", args.eval_batch_size) 618 | 619 | model.eval() 620 | eval_loss, eval_accuracy = 0, 0 621 | nb_eval_steps, nb_eval_examples = 0, 0 622 | logits_all = [] 623 | for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: 624 | input_ids = input_ids.to(device) 625 | input_mask = input_mask.to(device) 626 | segment_ids = segment_ids.to(device) 627 | label_ids = label_ids.to(device) 628 | 629 | with torch.no_grad(): 630 | tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids, n_class) 631 | 632 | logits = logits.detach().cpu().numpy() 633 | label_ids = label_ids.to('cpu').numpy() 634 | for i in range(len(logits)): 635 | logits_all += [logits[i]] 636 | 637 | tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1)) 638 | 639 | eval_loss += tmp_eval_loss.mean().item() 640 | 
eval_accuracy += tmp_eval_accuracy 641 | 642 | nb_eval_examples += input_ids.size(0) 643 | nb_eval_steps += 1 644 | 645 | eval_loss = eval_loss / nb_eval_steps 646 | eval_accuracy = eval_accuracy / nb_eval_examples 647 | 648 | if args.do_train: 649 | result = {'eval_loss': eval_loss, 650 | 'eval_accuracy': eval_accuracy, 651 | 'global_step': global_step, 652 | 'loss': tr_loss/nb_tr_steps} 653 | else: 654 | result = {'eval_loss': eval_loss, 655 | 'eval_accuracy': eval_accuracy} 656 | 657 | 658 | output_eval_file = os.path.join(args.output_dir, "eval_results_dev.txt") 659 | with open(output_eval_file, "w") as writer: 660 | logger.info("***** Eval results *****") 661 | for key in sorted(result.keys()): 662 | logger.info(" %s = %s", key, str(result[key])) 663 | writer.write("%s = %s\n" % (key, str(result[key]))) 664 | output_eval_file = os.path.join(args.output_dir, "logits_dev.txt") 665 | with open(output_eval_file, "w") as f: 666 | for i in range(len(logits_all)): 667 | for j in range(len(logits_all[i])): 668 | f.write(str(logits_all[i][j])) 669 | if j == len(logits_all[i])-1: 670 | f.write("\n") 671 | else: 672 | f.write(" ") 673 | 674 | 675 | eval_examples = processor.get_test_examples(args.data_dir) 676 | eval_features = convert_examples_to_features( 677 | eval_examples, label_list, args.max_seq_length, tokenizer) 678 | 679 | logger.info("***** Running evaluation *****") 680 | logger.info(" Num examples = %d", len(eval_examples)) 681 | logger.info(" Batch size = %d", args.eval_batch_size) 682 | 683 | input_ids = [] 684 | input_mask = [] 685 | segment_ids = [] 686 | label_id = [] 687 | 688 | for f in eval_features: 689 | input_ids.append([]) 690 | input_mask.append([]) 691 | segment_ids.append([]) 692 | for i in range(n_class): 693 | input_ids[-1].append(f[i].input_ids) 694 | input_mask[-1].append(f[i].input_mask) 695 | segment_ids[-1].append(f[i].segment_ids) 696 | label_id.append([f[0].label_id]) 697 | 698 | all_input_ids = torch.tensor(input_ids, dtype=torch.long) 699 | all_input_mask = torch.tensor(input_mask, dtype=torch.long) 700 | all_segment_ids = torch.tensor(segment_ids, dtype=torch.long) 701 | all_label_ids = torch.tensor(label_id, dtype=torch.long) 702 | 703 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 704 | if args.local_rank == -1: 705 | eval_sampler = SequentialSampler(eval_data) 706 | else: 707 | eval_sampler = DistributedSampler(eval_data) 708 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 709 | 710 | model.eval() 711 | eval_loss, eval_accuracy = 0, 0 712 | nb_eval_steps, nb_eval_examples = 0, 0 713 | logits_all = [] 714 | for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: 715 | input_ids = input_ids.to(device) 716 | input_mask = input_mask.to(device) 717 | segment_ids = segment_ids.to(device) 718 | label_ids = label_ids.to(device) 719 | 720 | with torch.no_grad(): 721 | tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids, n_class) 722 | 723 | logits = logits.detach().cpu().numpy() 724 | label_ids = label_ids.to('cpu').numpy() 725 | for i in range(len(logits)): 726 | logits_all += [logits[i]] 727 | 728 | tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1)) 729 | 730 | eval_loss += tmp_eval_loss.mean().item() 731 | eval_accuracy += tmp_eval_accuracy 732 | 733 | nb_eval_examples += input_ids.size(0) 734 | nb_eval_steps += 1 735 | 736 | eval_loss = eval_loss / nb_eval_steps 737 | eval_accuracy = eval_accuracy / nb_eval_examples 738 
| 739 | if args.do_train: 740 | result = {'eval_loss': eval_loss, 741 | 'eval_accuracy': eval_accuracy, 742 | 'global_step': global_step, 743 | 'loss': tr_loss/nb_tr_steps} 744 | else: 745 | result = {'eval_loss': eval_loss, 746 | 'eval_accuracy': eval_accuracy} 747 | 748 | 749 | output_eval_file = os.path.join(args.output_dir, "eval_results_test.txt") 750 | with open(output_eval_file, "w") as writer: 751 | logger.info("***** Eval results *****") 752 | for key in sorted(result.keys()): 753 | logger.info(" %s = %s", key, str(result[key])) 754 | writer.write("%s = %s\n" % (key, str(result[key]))) 755 | output_eval_file = os.path.join(args.output_dir, "logits_test.txt") 756 | with open(output_eval_file, "w") as f: 757 | for i in range(len(logits_all)): 758 | for j in range(len(logits_all[i])): 759 | f.write(str(logits_all[i][j])) 760 | if j == len(logits_all[i])-1: 761 | f.write("\n") 762 | else: 763 | f.write(" ") 764 | 765 | if __name__ == "__main__": 766 | main() 767 | -------------------------------------------------------------------------------- /bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | 26 | 27 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 28 | """Checks whether the casing config is consistent with the checkpoint name.""" 29 | 30 | # The casing has to be passed in by the user and there is no explicit check 31 | # as to whether it matches the checkpoint. The casing information probably 32 | # should have been stored in the bert_config.json file, but it's not, so 33 | # we have to heuristically detect it to validate. 34 | 35 | if not init_checkpoint: 36 | return 37 | 38 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 39 | if m is None: 40 | return 41 | 42 | model_name = m.group(1) 43 | 44 | lower_models = [ 45 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 46 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 47 | ] 48 | 49 | cased_models = [ 50 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 51 | "multi_cased_L-12_H-768_A-12" 52 | ] 53 | 54 | is_bad_config = False 55 | if model_name in lower_models and not do_lower_case: 56 | is_bad_config = True 57 | actual_flag = "False" 58 | case_name = "lowercased" 59 | opposite_flag = "True" 60 | 61 | if model_name in cased_models and do_lower_case: 62 | is_bad_config = True 63 | actual_flag = "True" 64 | case_name = "cased" 65 | opposite_flag = "False" 66 | 67 | if is_bad_config: 68 | raise ValueError( 69 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. 
" 70 | "However, `%s` seems to be a %s model, so you " 71 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 72 | "how the model was pre-training. If this error is wrong, please " 73 | "just comment out this check." % (actual_flag, init_checkpoint, 74 | model_name, case_name, opposite_flag)) 75 | 76 | 77 | def convert_to_unicode(text): 78 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 79 | if six.PY3: 80 | if isinstance(text, str): 81 | return text 82 | elif isinstance(text, bytes): 83 | return text.decode("utf-8", "ignore") 84 | else: 85 | raise ValueError("Unsupported string type: %s" % (type(text))) 86 | elif six.PY2: 87 | if isinstance(text, str): 88 | return text.decode("utf-8", "ignore") 89 | elif isinstance(text, unicode): 90 | return text 91 | else: 92 | raise ValueError("Unsupported string type: %s" % (type(text))) 93 | else: 94 | raise ValueError("Not running on Python2 or Python 3?") 95 | 96 | 97 | def printable_text(text): 98 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 99 | 100 | # These functions want `str` for both Python2 and Python3, but in one case 101 | # it's a Unicode string and in the other it's a byte string. 102 | if six.PY3: 103 | if isinstance(text, str): 104 | return text 105 | elif isinstance(text, bytes): 106 | return text.decode("utf-8", "ignore") 107 | else: 108 | raise ValueError("Unsupported string type: %s" % (type(text))) 109 | elif six.PY2: 110 | if isinstance(text, str): 111 | return text 112 | elif isinstance(text, unicode): 113 | return text.encode("utf-8") 114 | else: 115 | raise ValueError("Unsupported string type: %s" % (type(text))) 116 | else: 117 | raise ValueError("Not running on Python2 or Python 3?") 118 | 119 | 120 | def load_vocab(vocab_file): 121 | """Loads a vocabulary file into a dictionary.""" 122 | vocab = collections.OrderedDict() 123 | index = 0 124 | with open(vocab_file, "r", encoding='utf8') as reader: 125 | while True: 126 | token = convert_to_unicode(reader.readline()) 127 | if not token: 128 | break 129 | token = token.strip() 130 | vocab[token] = index 131 | index += 1 132 | return vocab 133 | 134 | 135 | def convert_by_vocab(vocab, items): 136 | """Converts a sequence of [tokens|ids] using the vocab.""" 137 | output = [] 138 | for item in items: 139 | output.append(vocab[item]) 140 | return output 141 | 142 | 143 | def convert_tokens_to_ids(vocab, tokens): 144 | return convert_by_vocab(vocab, tokens) 145 | 146 | 147 | def convert_ids_to_tokens(inv_vocab, ids): 148 | return convert_by_vocab(inv_vocab, ids) 149 | 150 | 151 | def whitespace_tokenize(text): 152 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 153 | text = text.strip() 154 | if not text: 155 | return [] 156 | tokens = text.split() 157 | return tokens 158 | 159 | 160 | class FullTokenizer(object): 161 | """Runs end-to-end tokenziation.""" 162 | 163 | def __init__(self, vocab_file, do_lower_case=True): 164 | self.vocab = load_vocab(vocab_file) 165 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 166 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 167 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 168 | 169 | def tokenize(self, text): 170 | split_tokens = [] 171 | for token in self.basic_tokenizer.tokenize(text): 172 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 173 | split_tokens.append(sub_token) 174 | 175 | return split_tokens 176 | 177 | def convert_tokens_to_ids(self, tokens): 178 | return 
convert_by_vocab(self.vocab, tokens) 179 | 180 | def convert_ids_to_tokens(self, ids): 181 | return convert_by_vocab(self.inv_vocab, ids) 182 | 183 | 184 | class BasicTokenizer(object): 185 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 186 | 187 | def __init__(self, do_lower_case=True): 188 | """Constructs a BasicTokenizer. 189 | Args: 190 | do_lower_case: Whether to lower case the input. 191 | """ 192 | self.do_lower_case = do_lower_case 193 | 194 | def tokenize(self, text): 195 | """Tokenizes a piece of text.""" 196 | text = convert_to_unicode(text) 197 | text = self._clean_text(text) 198 | 199 | # This was added on November 1st, 2018 for the multilingual and Chinese 200 | # models. This is also applied to the English models now, but it doesn't 201 | # matter since the English models were not trained on any Chinese data 202 | # and generally don't have any Chinese data in them (there are Chinese 203 | # characters in the vocabulary because Wikipedia does have some Chinese 204 | # words in the English Wikipedia.). 205 | text = self._tokenize_chinese_chars(text) 206 | 207 | orig_tokens = whitespace_tokenize(text) 208 | split_tokens = [] 209 | for token in orig_tokens: 210 | if self.do_lower_case: 211 | token = token.lower() 212 | token = self._run_strip_accents(token) 213 | split_tokens.extend(self._run_split_on_punc(token)) 214 | 215 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 216 | return output_tokens 217 | 218 | def _run_strip_accents(self, text): 219 | """Strips accents from a piece of text.""" 220 | text = unicodedata.normalize("NFD", text) 221 | output = [] 222 | for char in text: 223 | cat = unicodedata.category(char) 224 | if cat == "Mn": 225 | continue 226 | output.append(char) 227 | return "".join(output) 228 | 229 | def _run_split_on_punc(self, text): 230 | """Splits punctuation on a piece of text.""" 231 | chars = list(text) 232 | i = 0 233 | start_new_word = True 234 | output = [] 235 | while i < len(chars): 236 | char = chars[i] 237 | if _is_punctuation(char): 238 | output.append([char]) 239 | start_new_word = True 240 | else: 241 | if start_new_word: 242 | output.append([]) 243 | start_new_word = False 244 | output[-1].append(char) 245 | i += 1 246 | 247 | return ["".join(x) for x in output] 248 | 249 | def _tokenize_chinese_chars(self, text): 250 | """Adds whitespace around any CJK character.""" 251 | output = [] 252 | for char in text: 253 | cp = ord(char) 254 | if self._is_chinese_char(cp): 255 | output.append(" ") 256 | output.append(char) 257 | output.append(" ") 258 | else: 259 | output.append(char) 260 | return "".join(output) 261 | 262 | def _is_chinese_char(self, cp): 263 | """Checks whether CP is the codepoint of a CJK character.""" 264 | # This defines a "chinese character" as anything in the CJK Unicode block: 265 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 266 | # 267 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 268 | # despite its name. The modern Korean Hangul alphabet is a different block, 269 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 270 | # space-separated words, so they are not treated specially and handled 271 | # like the all of the other languages. 
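    # The ranges below correspond to the CJK Unified Ideographs block, its Extensions A-E,
    # and the CJK Compatibility Ideographs block plus its supplement.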
272 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 273 | (cp >= 0x3400 and cp <= 0x4DBF) or # 274 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 275 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 276 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 277 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 278 | (cp >= 0xF900 and cp <= 0xFAFF) or # 279 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 280 | return True 281 | 282 | return False 283 | 284 | def _clean_text(self, text): 285 | """Performs invalid character removal and whitespace cleanup on text.""" 286 | output = [] 287 | for char in text: 288 | cp = ord(char) 289 | if cp == 0 or cp == 0xfffd or _is_control(char): 290 | continue 291 | if _is_whitespace(char): 292 | output.append(" ") 293 | else: 294 | output.append(char) 295 | return "".join(output) 296 | 297 | 298 | class WordpieceTokenizer(object): 299 | """Runs WordPiece tokenziation.""" 300 | 301 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 302 | self.vocab = vocab 303 | self.unk_token = unk_token 304 | self.max_input_chars_per_word = max_input_chars_per_word 305 | 306 | def tokenize(self, text): 307 | """Tokenizes a piece of text into its word pieces. 308 | This uses a greedy longest-match-first algorithm to perform tokenization 309 | using the given vocabulary. 310 | For example: 311 | input = "unaffable" 312 | output = ["un", "##aff", "##able"] 313 | Args: 314 | text: A single token or whitespace separated tokens. This should have 315 | already been passed through `BasicTokenizer. 316 | Returns: 317 | A list of wordpiece tokens. 318 | """ 319 | 320 | text = convert_to_unicode(text) 321 | 322 | output_tokens = [] 323 | for token in whitespace_tokenize(text): 324 | chars = list(token) 325 | if len(chars) > self.max_input_chars_per_word: 326 | output_tokens.append(self.unk_token) 327 | continue 328 | 329 | is_bad = False 330 | start = 0 331 | sub_tokens = [] 332 | while start < len(chars): 333 | end = len(chars) 334 | cur_substr = None 335 | while start < end: 336 | substr = "".join(chars[start:end]) 337 | if start > 0: 338 | substr = "##" + substr 339 | if substr in self.vocab: 340 | cur_substr = substr 341 | break 342 | end -= 1 343 | if cur_substr is None: 344 | is_bad = True 345 | break 346 | sub_tokens.append(cur_substr) 347 | start = end 348 | 349 | if is_bad: 350 | output_tokens.append(self.unk_token) 351 | else: 352 | output_tokens.extend(sub_tokens) 353 | return output_tokens 354 | 355 | 356 | def _is_whitespace(char): 357 | """Checks whether `chars` is a whitespace character.""" 358 | # \t, \n, and \r are technically contorl characters but we treat them 359 | # as whitespace since they are generally considered as such. 360 | if char == " " or char == "\t" or char == "\n" or char == "\r": 361 | return True 362 | cat = unicodedata.category(char) 363 | if cat == "Zs": 364 | return True 365 | return False 366 | 367 | 368 | def _is_control(char): 369 | """Checks whether `chars` is a control character.""" 370 | # These are technically control characters but we count them as whitespace 371 | # characters. 372 | if char == "\t" or char == "\n" or char == "\r": 373 | return False 374 | cat = unicodedata.category(char) 375 | if cat.startswith("C"): 376 | return True 377 | return False 378 | 379 | 380 | def _is_punctuation(char): 381 | """Checks whether `chars` is a punctuation character.""" 382 | cp = ord(char) 383 | # We treat all non-letter/number ASCII as punctuation. 
384 | # Characters such as "^", "$", and "`" are not in the Unicode 385 | # Punctuation class but we treat them as punctuation anyway, for 386 | # consistency. 387 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 388 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 389 | return True 390 | cat = unicodedata.category(char) 391 | if cat.startswith("P"): 392 | return True 393 | return False 394 | -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | The C3 dataset is intended for non-commercial research purposes only. 2 | --------------------------------------------------------------------------------