├── .gitignore
├── README.md
├── annotation
│   ├── c3-d-dev.txt
│   ├── c3-d-test.txt
│   ├── c3-m-dev.txt
│   └── c3-m-test.txt
├── bert
│   ├── LICENSE
│   ├── __init__.py
│   ├── convert_tf_checkpoint_to_pytorch.py
│   ├── extract_features.py
│   ├── modeling.py
│   ├── optimization.py
│   ├── run_classifier.py
│   └── tokenization.py
├── data
│   ├── c3-d-dev.json
│   ├── c3-d-test.json
│   ├── c3-d-train.json
│   ├── c3-m-dev.json
│   ├── c3-m-test.json
│   └── c3-m-train.json
└── license.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | C3
2 | =====
3 | Overview
4 | --------
5 | This repository maintains **C3**, the first free-form multiple-**C**hoice **C**hinese machine reading **C**omprehension dataset.
6 |
7 | * Paper: https://arxiv.org/abs/1904.09679
8 | ```
9 | @article{sun2019investigating,
10 | title={Investigating Prior Knowledge for Challenging Chinese Machine Reading Comprehension},
11 | author={Sun, Kai and Yu, Dian and Yu, Dong and Cardie, Claire},
12 | journal={Transactions of the Association for Computational Linguistics},
13 | year={2020},
14 | url={https://arxiv.org/abs/1904.09679v3}
15 | }
16 | ```
17 |
18 | Files in this repository:
19 |
20 | * ```license.txt```: the license of C3.
21 | * ```data/c3-{m,d}-{train,dev,test}.json```: the dataset files, where m and d represent "**m**ixed-genre" and "**d**ialogue", respectively. The data format is as follows; a short loading example follows the block.
22 | ```
23 | [
24 | [
25 | [
26 | document 1
27 | ],
28 | [
29 | {
30 | "question": document 1 / question 1,
31 | "choice": [
32 | document 1 / question 1 / answer option 1,
33 | document 1 / question 1 / answer option 2,
34 | ...
35 | ],
36 | "answer": document 1 / question 1 / correct answer option
37 | },
38 | {
39 | "question": document 1 / question 2,
40 | "choice": [
41 | document 1 / question 2 / answer option 1,
42 | document 1 / question 2 / answer option 2,
43 | ...
44 | ],
45 | "answer": document 1 / question 2 / correct answer option
46 | },
47 | ...
48 | ],
49 | document 1 / id
50 | ],
51 | [
52 | [
53 | document 2
54 | ],
55 | [
56 | {
57 | "question": document 2 / question 1,
58 | "choice": [
59 | document 2 / question 1 / answer option 1,
60 | document 2 / question 1 / answer option 2,
61 | ...
62 | ],
63 | "answer": document 2 / question 1 / correct answer option
64 | },
65 | {
66 | "question": document 2 / question 2,
67 | "choice": [
68 | document 2 / question 2 / answer option 1,
69 | document 2 / question 2 / answer option 2,
70 | ...
71 | ],
72 | "answer": document 2 / question 2 / correct answer option
73 | },
74 | ...
75 | ],
76 | document 2 / id
77 | ],
78 | ...
79 | ]
80 | ```
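
For illustration, here is a minimal sketch of loading one of the dataset files with Python's standard ```json``` module, assuming the layout shown above (the file path is just an example):

```
import json

# Any of the c3-{m,d}-{train,dev,test}.json files share the layout sketched above.
with open("data/c3-d-train.json", encoding="utf-8") as f:
    dataset = json.load(f)

# Each entry is [document segments, question list, document id].
for segments, questions, doc_id in dataset:
    for q in questions:
        # q["answer"] is the text of the correct option and appears in q["choice"].
        assert q["answer"] in q["choice"]

print(len(dataset), "documents,", sum(len(qs) for _, qs, _ in dataset), "questions")
```
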
81 | * ```annotation/c3-{m,d}-{dev,test}.txt```: question type annotations. Each file contains 150 annotated instances. We adopt the following abbreviations (a small parsing sketch follows the table):
82 |
83 | |                      | Abbreviation | Question Type      |
84 | | -------------------- | ------------ | ------------------ |
85 | | Matching             | m            | Matching           |
86 | | Prior knowledge      | l            | Linguistic         |
87 | |                      | s            | Domain-specific    |
88 | |                      | c-a          | Arithmetic         |
89 | |                      | c-o          | Connotation        |
90 | |                      | c-e          | Cause-effect       |
91 | |                      | c-i          | Implication        |
92 | |                      | c-p          | Part-whole         |
93 | |                      | c-d          | Precondition       |
94 | |                      | c-h          | Scenario           |
95 | |                      | c-n          | Other              |
96 | | Supporting Sentences | 0            | Single Sentence    |
97 | |                      | 1            | Multiple sentences |
98 | |                      | 2            | Independent        |
99 | 
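
For illustration, here is a minimal sketch of reading one annotation file, assuming whitespace-separated columns and annotations of the form ```type1|type2|...||supportingSentencesLabel``` (the file path is just an example):

```
from collections import Counter

type_counts = Counter()
with open("annotation/c3-d-dev.txt", encoding="utf-8") as f:
    next(f)  # skip the "documentID questionIndex type" header
    for line in f:
        if not line.strip():
            continue
        doc_id, q_index, annotation = line.split()
        # Question-type abbreviations are joined by "|"; the label after "||"
        # is the supporting-sentences category (0, 1, or 2) from the table above.
        types, _, support = annotation.partition("||")
        type_counts.update(types.split("|"))

print(type_counts.most_common())
```
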
152 | * ```bert``` folder: code of Chinese BERT, BERT-wwm, and BERT-wwm-ext baselines. The code is derived from [this repository](https://github.com/nlpdata/mrc_bert_baseline). Below are detailed instructions on fine-tuning Chinese BERT on C3.
153 | 1. Download and unzip the pre-trained Chinese BERT from [here](https://github.com/google-research/bert), and set up the environment variable for BERT by ```export BERT_BASE_DIR=/PATH/TO/BERT/DIR```.
154 | 2. Copy the dataset folder ```data``` to ```bert/```.
155 | 3. In ```bert```, execute ```python convert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path=$BERT_BASE_DIR/bert_model.ckpt --bert_config_file=$BERT_BASE_DIR/bert_config.json --pytorch_dump_path=$BERT_BASE_DIR/pytorch_model.bin```.
156 | 4. Execute ```python run_classifier.py --task_name c3 --do_train --do_eval --data_dir . --vocab_file $BERT_BASE_DIR/vocab.txt --bert_config_file $BERT_BASE_DIR/bert_config.json --init_checkpoint $BERT_BASE_DIR/pytorch_model.bin --max_seq_length 512 --train_batch_size 24 --learning_rate 2e-5 --num_train_epochs 8.0 --output_dir c3_finetuned --gradient_accumulation_steps 3```.
157 | 5. The resulting fine-tuned model, predictions, and evaluation results are stored in ```bert/c3_finetuned```.
158 |
159 | **Note**:
160 | 1. Fine-tuning Chinese BERT-wwm or BERT-wwm-ext follows the same steps; just download the corresponding pre-trained language model instead.
161 | 2. There is randomness in model training, so you may want to train multiple times with different seeds (specify ```--seed``` when executing ```run_classifier.py```) and pick the best model based on development set performance.
162 | 3. Depending on your hardware (e.g., available GPU memory), you may need to change ```gradient_accumulation_steps```.
163 | 4. The code has been tested with Python 3.6 and PyTorch 1.0.
164 |
--------------------------------------------------------------------------------
/annotation/c3-d-dev.txt:
--------------------------------------------------------------------------------
1 | documentID questionIndex type
2 | 47-275 1 c-i||1
3 | 11-57 1 c-d||1
4 | 11-30 1 m||0
5 | 17-3 1 l||1
6 | m27-11 1 m||0
7 | m27-11 2 m||0
8 | m27-11 3 l||1
9 | m27-11 4 l|c-e||1
10 | m27-11 5 l||1
11 | 46-196 1 c-i||1
12 | 28-8 1 c-o||1
13 | 46-151 1 c-o||1
14 | 2-29 1 m||1
15 | 37-330 1 c-o||1
16 | 7-10 1 l||1
17 | 37-120 1 c-h||1
18 | m12-8 1 c-h||1
19 | m12-8 2 m||1
20 | 32-96 1 c-h||1
21 | 10-81 1 l||1
22 | 57-158 1 c-h||1
23 | 56-261 1 l||0
24 | 34-65 1 c-h||1
25 | 46-40 1 s||1
26 | m19-1 1 l||1
27 | m19-1 2 m||0
28 | m19-1 3 m||0
29 | m19-1 4 l|c-e||1
30 | m19-1 5 l||0
31 | 41-38 1 c-i||1
32 | 46-91 1 c-i||1
33 | 54-217 1 m||0
34 | 51-234 1 c-i||1
35 | 33-72 1 l||1
36 | 34-117 1 l|c-i||1
37 | 53-121 1 c-h||1
38 | 54-206 1 c-d||1
39 | 54-358 1 c-h||1
40 | 54-103 1 m||0
41 | 35-63 1 c-i||1
42 | 34-63 1 c-i||1
43 | 52-141 1 l||1
44 | 42-37 1 m||0
45 | 34-4 1 c-p||0
46 | 56-29 1 c-h||1
47 | 57-214 1 c-h||1
48 | 57-203 1 l|c-e||1
49 | 49-85 1 c-o||1
50 | 37-240 1 c-d|c-h||1
51 | 42-103 1 c-h||1
52 | 38-54 1 l||1
53 | 56-234 1 l||1
54 | 12-62 1 l||1
55 | 41-134 1 s|c-a|c-p||1
56 | 32-200 1 c-i||1
57 | 35-9 1 c-i||1
58 | 5-59 1 l||1
59 | 28-5 1 c-a||1
60 | 17-40 1 c-o||1
61 | 4-14 1 c-o||1
62 | 37-40 1 m||0
63 | 52-112 1 l||0
64 | 56-283 1 l||0
65 | 36-38 1 m||0
66 | m15-4 1 l||1
67 | m15-4 2 c-h||1
68 | 56-19 1 m||0
69 | 52-176 1 m||0
70 | 37-290 1 m||1
71 | 47-40 1 c-i||1
72 | 19-26 1 c-h||1
73 | 6-13 1 c-h||1
74 | 37-341 1 c-e|c-i|l||1
75 | 39-116 1 c-d||1
76 | 29-128 1 c-h||1
77 | 46-178 1 c-i||1
78 | 49-81 1 c-p||0
79 | 3-37 1 c-a||1
80 | 51-14 1 l||1
81 | 56-89 1 m||0
82 | 35-98 1 m||1
83 | 42-63 1 s||1
84 | m29-7 1 l|c-e||0
85 | m29-7 2 l||1
86 | m29-7 3 l||0
87 | m29-7 4 c-h||1
88 | m29-7 5 l||1
89 | 46-85 1 l||1
90 | 44-74 1 c-h||1
91 | 48-117 1 l||1
92 | 57-45 1 c-i||1
93 | 40-192 1 l||1
94 | 40-124 1 c-h||1
95 | 55-63 1 c-h||1
96 | 38-103 1 m||0
97 | 43-43 1 m||1
98 | 37-45 1 c-h||1
99 | 34-182 1 c-h||1
100 | 54-388 1 c-i|c-p||1
101 | 53-106 1 c-o||1
102 | 20-3 1 c-i||0
103 | 45-8 1 m||0
104 | 39-196 1 c-h||1
105 | 50-84 1 c-i|c-p||1
106 | 37-277 1 c-i||1
107 | 47-95 1 c-i||1
108 | 35-60 1 c-h||1
109 | 34-192 1 c-h||1
110 | 48-245 1 c-h|l||1
111 | 54-9 1 l||0
112 | 27-35 1 c-a||1
113 | 37-377 1 c-o||1
114 | 48-37 1 l||1
115 | 20-36 1 l||1
116 | 38-92 1 l||0
117 | 39-244 1 l||1
118 | 29-68 1 m||0
119 | 48-240 1 l||1
120 | 57-134 1 c-i|c-a||1
121 | 56-291 1 m||1
122 | 39-121 1 m||1
123 | 48-193 1 c-h||1
124 | 52-168 1 c-h||1
125 | 11-55 1 l||0
126 | 21-6 1 c-h||1
127 | 56-5 1 c-h||1
128 | 45-33 1 l||1
129 | 37-4 1 l||0
130 | 28-50 1 c-h||1
131 | 5-21 1 c-e||1
132 | 37-231 1 m||1
133 | 54-315 1 m||0
134 | 31-135 1 l|c-e||1
135 | 56-195 1 c-n||1
136 | 29-102 1 l||0
137 | 56-177 1 l||0
138 | 41-132 1 l||1
139 | 54-44 1 l|c-e||1
140 | 20-30 1 l||1
141 | 51-227 1 c-i||1
142 | 9-2 1 m||1
143 | 41-83 1 l|c-e||0
144 | 47-118 1 l||0
145 | 35-71 1 m||1
146 | 55-66 1 c-h||1
147 | 32-175 1 c-i||1
148 | m33-21 1 l||1
149 | m33-21 2 l||1
150 | m33-21 3 c-p||1
151 | m33-21 4 l||1
152 |
--------------------------------------------------------------------------------
/annotation/c3-d-test.txt:
--------------------------------------------------------------------------------
1 | documentID questionIndex type
2 | 47-37 1 l||1
3 | 50-94 1 c-i||0
4 | 48-292 1 c-o||0
5 | 47-248 1 c-i||1
6 | 26-30 1 c-e||1
7 | 10-114 1 c-e|c-i||1
8 | 38-119 1 c-d||1
9 | 34-203 1 c-i|c-p||1
10 | m23-4 1 m||0
11 | m23-4 2 l||1
12 | m23-4 3 l||1
13 | m23-4 4 m||1
14 | m23-4 5 m||0
15 | 32-6 1 c-i||0
16 | 37-84 1 l|c-i||1
17 | 47-104 1 c-h|c-i||1
18 | 53-57 1 l|c-h|c-i||0
19 | 49-148 1 c-h||1
20 | 56-212 1 c-h||1
21 | 40-255 1 c-h|l||1
22 | 29-88 1 m||1
23 | m22-24 1 l||0
24 | m22-24 2 c-h||1
25 | m22-24 3 l||1
26 | m22-24 4 l||0
27 | m22-24 5 l||1
28 | 30-174 1 c-i||1
29 | 49-24 1 c-a||1
30 | 31-196 1 c-d||0
31 | 57-58 1 c-h||1
32 | 29-204 1 m||1
33 | 52-9 1 c-h||1
34 | 33-85 1 c-a||1
35 | 37-352 1 c-o||1
36 | 46-271 1 c-h|l||1
37 | 56-49 1 m||0
38 | 10-24 1 c-h||1
39 | 40-73 1 c-i||1
40 | 29-109 1 m||0
41 | 37-363 1 c-i|c-p||1
42 | 34-66 1 c-h||1
43 | 40-32 1 c-p||1
44 | 3-50 1 l||0
45 | 44-116 1 l||1
46 | 32-116 1 l||1
47 | 46-37 1 l||0
48 | 52-79 1 c-h|c-e||1
49 | 56-90 1 c-e|c-i||1
50 | 5-22 1 c-h||1
51 | 54-139 1 c-i||1
52 | 56-165 1 c-i||1
53 | 50-132 1 c-h||1
54 | 49-32 1 c-o||0
55 | 41-216 1 c-h||1
56 | 32-174 1 c-h||1
57 | 29-229 1 c-h||1
58 | 37-379 1 c-i||0
59 | 9-4 1 l||1
60 | 51-47 1 c-h||1
61 | 51-122 1 c-a|l||1
62 | 26-21 1 l||1
63 | 46-150 1 c-i|l||1
64 | 43-8 1 c-i||1
65 | 17-38 1 c-h|c-d||1
66 | 56-102 1 c-h||0
67 | 39-102 1 c-i||0
68 | 24-11 1 c-h||1
69 | 51-149 1 l|c-a||1
70 | 48-202 1 c-e|c-i||1
71 | 13-65 1 c-h||1
72 | 47-231 1 c-i||1
73 | 37-210 1 c-i||0
74 | 15-10 1 c-o||1
75 | 33-100 1 c-a||1
76 | 52-53 1 c-h||0
77 | 33-94 1 c-i||1
78 | 35-99 1 c-e|c-p||1
79 | 19-41 1 c-i||1
80 | 46-243 1 m||0
81 | 12-32 1 c-h||1
82 | 32-77 1 c-h||1
83 | 37-157 1 c-i||1
84 | 57-40 1 m||0
85 | 40-299 1 c-i||1
86 | 41-121 1 c-p|l||1
87 | 10-39 1 c-i||0
88 | 54-223 1 c-o||1
89 | 6-8 1 c-e||0
90 | 46-272 1 c-p|c-a||0
91 | 10-86 1 c-h|l||1
92 | 48-95 1 c-i||0
93 | 10-21 1 c-i||1
94 | 41-211 1 c-i||1
95 | 51-204 1 c-d||0
96 | 39-117 1 c-h||1
97 | 39-132 1 c-e|l||1
98 | 48-214 1 c-h|l||1
99 | m2-3 1 c-d||1
100 | m2-3 2 m||0
101 | m2-3 3 c-i|l||1
102 | m2-3 4 c-p|l||1
103 | m2-3 5 c-i|l||1
104 | 45-130 1 c-o||1
105 | 35-34 1 c-h||1
106 | 54-95 1 c-h||1
107 | 20-21 1 c-i||1
108 | 53-2 1 c-i||1
109 | 56-230 1 c-h|c-d||1
110 | 7-3 1 c-d||1
111 | m23-23 1 m||1
112 | m23-23 2 c-o|l||1
113 | m23-23 3 c-h||0
114 | m23-23 4 c-i||1
115 | m23-23 5 m||1
116 | 57-212 1 c-h||1
117 | 29-43 1 c-i|c-h||1
118 | 32-104 1 m||0
119 | 46-70 1 l||1
120 | 45-52 1 m||0
121 | 5-35 1 l||1
122 | 56-215 1 c-h||1
123 | 32-27 1 c-h||1
124 | 32-9 1 l||1
125 | 37-114 1 c-i||1
126 | m33-12 1 l||1
127 | m33-12 2 l||1
128 | m33-12 3 l||1
129 | m33-12 4 l||2
130 | 5-60 1 c-d||0
131 | 36-16 1 c-e||1
132 | 45-67 1 c-h||1
133 | 40-187 1 l||0
134 | 46-161 1 c-e|l||1
135 | 56-167 1 c-i|c-p||1
136 | 52-103 1 l||1
137 | 10-28 1 m||1
138 | 54-96 1 c-a|l||0
139 | 40-242 1 c-h||1
140 | 23-1 1 c-i||1
141 | 32-90 1 c-e|l||1
142 | 39-236 1 c-p|c-a||1
143 | 20-69 1 l|c-a||1
144 | 49-136 1 c-d||1
145 | 50-19 1 c-h||1
146 | 50-81 1 c-i||1
147 | 12-22 1 c-o||1
148 | 41-102 1 c-h||1
149 | 32-119 1 c-h||1
150 | 56-282 1 c-h||1
151 | m28-9 1 m||0
152 |
--------------------------------------------------------------------------------
/annotation/c3-m-dev.txt:
--------------------------------------------------------------------------------
1 | documentID questionIndex type
2 | 11-67 1 c-h||1
3 | 8-707 1 l||0
4 | m1-36 1 l||0
5 | m1-36 2 l||1
6 | m1-36 3 c-i||1
7 | m6-130 1 l||1
8 | m6-130 2 l||1
9 | m6-130 3 l|c-e||1
10 | m6-130 4 l|c-e||1
11 | 8-160 1 c-a|l||0
12 | 11-158 1 c-i||1
13 | 10-74 1 l||0
14 | 9-520 1 c-d||0
15 | 9-269 1 l|c-e||1
16 | 12-165 1 m||0
17 | 9-64 1 c-o||0
18 | 1-109 1 c-i||1
19 | 8-632 1 l||0
20 | 8-333 1 l||0
21 | 12-200 1 c-h||0
22 | 12-47 1 l||0
23 | m1-43 1 l||1
24 | m1-43 2 l|c-e||1
25 | m1-43 3 c-h||1
26 | m1-43 4 c-i|c-p||1
27 | 9-12 1 l||0
28 | m7-55 1 c-e|l||0
29 | m7-55 2 c-e|c-h||1
30 | m7-55 3 l||1
31 | m7-55 4 c-i||1
32 | 10-38 1 l||0
33 | 2-227 1 l||0
34 | m6-132 1 m||0
35 | m6-132 2 c-i||1
36 | 8-381 1 c-h||0
37 | 9-392 1 l||0
38 | 8-684 1 c-d||0
39 | 9-190 1 c-h||0
40 | m12-235 1 c-i||1
41 | m12-235 2 c-a||1
42 | m12-235 3 l||1
43 | m12-235 4 l||1
44 | 9-275 1 c-i||1
45 | m13-358 1 l||0
46 | m13-358 2 l||0
47 | m13-358 3 l||0
48 | m13-358 4 l||0
49 | m13-234 1 l||0
50 | m13-234 2 c-e|l||1
51 | m13-234 3 c-h||0
52 | m13-234 4 c-i||1
53 | m8-101 1 c-i||1
54 | m8-101 2 m||0
55 | m8-101 3 l||0
56 | m8-101 4 c-i||1
57 | 12-258 1 l||0
58 | 11-272 1 c-p|l||0
59 | 2-15 1 m||0
60 | m10-81 1 l||1
61 | m10-81 2 c-e|l||1
62 | m10-81 3 c-e|c-h||1
63 | m10-81 4 c-i||1
64 | 8-436 1 m||0
65 | m13-144 1 c-e|c-h||1
66 | m13-144 2 l||0
67 | m13-144 3 l||0
68 | m13-144 4 l||2
69 | 10-228 1 l||0
70 | m11-217 1 c-i||1
71 | m11-217 2 c-e|c-h||1
72 | m11-217 3 l||0
73 | m11-217 4 c-i||1
74 | 11-14 1 c-e||1
75 | 11-195 1 m||0
76 | m5-49 1 c-e|l||1
77 | m5-49 2 l||1
78 | 9-650 1 c-h||1
79 | m12-111 1 c-e|l||0
80 | m12-111 2 c-i||1
81 | m6-88 1 c-p|l||1
82 | m6-88 2 c-e|l||0
83 | m6-88 3 c-i||0
84 | 11-355 1 c-i||1
85 | 8-84 1 l||0
86 | 6-14 1 c-i||0
87 | 10-165 1 l||0
88 | 9-361 1 l||0
89 | m8-41 1 l||1
90 | m8-41 2 l||0
91 | m13-145 1 l||0
92 | m13-145 2 c-i||1
93 | m13-145 3 c-p|c-h||1
94 | m13-145 4 l||1
95 | 9-497 1 c-i||0
96 | m3-16 1 l||1
97 | m3-16 2 l||1
98 | 9-71 1 l||0
99 | 11-182 1 l||0
100 | m7-32 1 m||0
101 | m7-32 2 m||0
102 | m7-32 3 l||1
103 | 12-218 1 l||0
104 | 2-131 1 c-h||0
105 | 9-350 1 l||0
106 | 9-387 1 l||0
107 | 9-556 1 c-d||0
108 | m11-242 1 l||1
109 | m11-242 2 l||0
110 | m11-242 3 c-e|l||0
111 | m11-242 4 c-i||1
112 | m2-23 1 c-e|l||0
113 | m2-23 2 c-i||1
114 | m12-174 1 l||1
115 | m12-174 2 l||1
116 | m12-174 3 c-i||1
117 | m12-174	4	c-p|c-i||1
118 | m11-239 1 c-e|l||0
119 | m11-239 2 l||0
120 | m11-239 3 m||0
121 | m11-239 4 c-i||1
122 | 2-113 1 l||0
123 | 8-727 1 c-i||0
124 | m5-191 1 m||0
125 | m5-191 2 l||1
126 | m5-191 3 l||0
127 | m5-191 4 l||1
128 | m2-74 1 c-e||1
129 | m2-74 2 c-e|l||1
130 | m2-74 3 l||1
131 | m2-74 4 c-i||1
132 | 10-313 1 l||0
133 | 9-580 1 c-h||0
134 | 8-181 1 c-h||0
135 | m12-218 1 l||1
136 | m12-218 2 c-i||1
137 | m12-218 3 l||1
138 | m12-218 4 c-i||1
139 | m10-47 1 l||2
140 | m10-47 2 l||0
141 | m10-47 3 c-o||1
142 | m10-47 4 l||1
143 | m13-198 1 m||0
144 | m13-198 2 l||1
145 | m13-198 3 c-e||0
146 | 9-577 1 c-i||0
147 | 3-28 1 l||1
148 | 12-236 1 l||1
149 | m11-277 1 l||0
150 | m11-277 2 l||0
151 | m11-277 3 l||0
152 |
--------------------------------------------------------------------------------
/annotation/c3-m-test.txt:
--------------------------------------------------------------------------------
1 | documentID questionIndex type
2 | 12-21 1 l||1
3 | 3-66 1 c-i||1
4 | m12-64 1 c-n||1
5 | m12-64 2 m||1
6 | m12-64 3 c-d||1
7 | m12-64 4 l||1
8 | m5-132 1 l||0
9 | m5-132 2 c-n||1
10 | 1-146 1 c-h|l||1
11 | 11-363 1 l||0
12 | 8-671 1 c-i||0
13 | m7-83 1 c-e|c-i||1
14 | m7-83 2 c-i||1
15 | m7-83 3 c-i||1
16 | m7-83 4 c-i||1
17 | 9-336 1 c-h|c-i||1
18 | 12-350 1 l||0
19 | 8-53 1 c-a||0
20 | m12-108 1 c-p|l||0
21 | m12-108 2 c-n||1
22 | m12-108 3 m||0
23 | 9-120 1 c-a|c-n||0
24 | m13-158 1 m||1
25 | m13-158 2 c-o||1
26 | m13-158 3 l||0
27 | 9-133 1 c-h||0
28 | m11-90 1 c-e|c-i||0
29 | m11-90 2 l||0
30 | m11-90 3 l||1
31 | 10-55 1 m||1
32 | m5-112 1 m||0
33 | m5-112 2 m||0
34 | m5-112 3 m||0
35 | m5-112 4 l||1
36 | 12-397 1 l||1
37 | m12-76 1 c-e|l||1
38 | m12-76 2 c-e|l||1
39 | m12-76 3 c-e|c-i||1
40 | m12-76 4 c-i|l||1
41 | m7-129 1 c-e|c-p||1
42 | m7-129 2 m||1
43 | m7-129 3 c-e||0
44 | m7-129 4 c-a|c-p||1
45 | m6-2 1 m||0
46 | m6-2 2 c-a|c-h||1
47 | m6-2 3 c-e||0
48 | m6-2 4 l||0
49 | m12-129 1 m||0
50 | m12-129 2 m||1
51 | m12-129 3 m||0
52 | m8-81 1 c-d||1
53 | m8-81 2 s|c-a||1
54 | m8-81 3 s|c-n||2
55 | m8-81 4 c-n||1
56 | 9-5 1 c-i||0
57 | m11-26 1 m||0
58 | m11-26 2 c-e|l||1
59 | m11-26 3 c-e|c-h||1
60 | 3-3 1 l||0
61 | m12-4 1 c-e|l||1
62 | m12-4 2 c-e|l||1
63 | m12-4 3 c-i||1
64 | 12-406 1 m||1
65 | m12-207 1 l||1
66 | m12-207 2 c-e|l||1
67 | m12-207 3 c-n||1
68 | m12-207 4 l|c-d||1
69 | 9-97 1 c-h||1
70 | 8-572 1 c-h||0
71 | 1-21 1 c-h||0
72 | m11-34 1 c-e|l||0
73 | m11-34 2 m||1
74 | m11-34 3 c-i||1
75 | 9-65 1 c-i||0
76 | 8-306 1 c-h||0
77 | m6-30 1 c-e|l||0
78 | m6-30 2 l|c-p||0
79 | m6-30 3 l||0
80 | m12-172 1 m||1
81 | m12-172 2 c-e|l||1
82 | m12-172 3 m||1
83 | m12-172 4 l||1
84 | 8-255 1 c-n||0
85 | m11-243 1 c-e|l||0
86 | m11-243 2 l||1
87 | m11-243 3 l||0
88 | m11-243 4 l||1
89 | 2-161 1 l||0
90 | 11-36 1 m||0
91 | 12-26 1 l||0
92 | 1-5 1 c-i||1
93 | m10-7 1 c-n||0
94 | m10-7 2 c-p|c-i||1
95 | 8-96 1 c-e|l||1
96 | 10-4 1 c-i||1
97 | 11-81 1 l||0
98 | 8-652 1 c-a|c-p||0
99 | m5-199 1 m||0
100 | m5-199 2 m||0
101 | m5-199 3 m||0
102 | m5-95 1 c-i||1
103 | m5-95 2 l||1
104 | m5-95 3 l||2
105 | m5-95 4 l||2
106 | m5-95 5 c-o||1
107 | m5-129 1 c-h||0
108 | m5-129 2 c-p||0
109 | m5-129 3 l||1
110 | 8-658 1 c-e||0
111 | 8-251 1 l||0
112 | 2-153 1 c-h||0
113 | m5-119 1 c-a||1
114 | m5-119 2 l||0
115 | 8-112 1 l||0
116 | 8-711 1 c-i||0
117 | m2-20 1 m||0
118 | m2-20 2 c-p||0
119 | 12-127 1 l||0
120 | m11-111 1 c-h||1
121 | m11-111 2 l||2
122 | m11-111 3 m||0
123 | m11-111 4 l||1
124 | m11-111 5 c-h||1
125 | 6-15 1 l||0
126 | 8-122 1 c-i||0
127 | 9-438 1 c-d||0
128 | 8-466 1 l||0
129 | 10-316 1 c-h||1
130 | 5-24 1 l||2
131 | 12-60 1 l||0
132 | 1-181 1 l||0
133 | 6-7 1 l||0
134 | 9-44 1 c-i||0
135 | 8-451 1 m||0
136 | 9-245 1 c-d||0
137 | m5-4 1 c-i|c-p||1
138 | m5-4 2 l||0
139 | m5-4 3 l||0
140 | m5-4 4 c-n||1
141 | 12-10 1 c-i||0
142 | 8-557 1 l||0
143 | m1-22 1 l||0
144 | m1-22 2 l||0
145 | m10-20 1 l||1
146 | m10-20 2 c-e|l||0
147 | m10-20 3 c-e|l||1
148 | 3-89 1 m||0
149 | m12-307 1 c-p||0
150 | m12-307 2 l||0
151 | m12-307 3 c-i||1
152 |
--------------------------------------------------------------------------------
/bert/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/bert/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/bert/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import argparse
23 | import tensorflow as tf
24 | import torch
25 | import numpy as np
26 |
27 | from modeling import BertConfig, BertModel
28 |
29 | parser = argparse.ArgumentParser()
30 |
31 | ## Required parameters
32 | parser.add_argument("--tf_checkpoint_path",
33 | default = None,
34 | type = str,
35 | required = True,
36 |                     help = "Path to the TensorFlow checkpoint.")
37 | parser.add_argument("--bert_config_file",
38 | default = None,
39 | type = str,
40 | required = True,
41 | help = "The config json file corresponding to the pre-trained BERT model. \n"
42 | "This specifies the model architecture.")
43 | parser.add_argument("--pytorch_dump_path",
44 | default = None,
45 | type = str,
46 | required = True,
47 | help = "Path to the output PyTorch model.")
48 |
49 | args = parser.parse_args()
50 |
51 | def convert():
52 | # Initialise PyTorch model
53 | config = BertConfig.from_json_file(args.bert_config_file)
54 | model = BertModel(config)
55 |
56 | # Load weights from TF model
57 | path = args.tf_checkpoint_path
58 | print("Converting TensorFlow checkpoint from {}".format(path))
59 |
60 | init_vars = tf.train.list_variables(path)
61 | names = []
62 | arrays = []
63 | for name, shape in init_vars:
64 | print("Loading {} with shape {}".format(name, shape))
65 | array = tf.train.load_variable(path, name)
66 | print("Numpy array shape {}".format(array.shape))
67 | names.append(name)
68 | arrays.append(array)
69 |
70 | for name, array in zip(names, arrays):
71 |         name = name[5:]  # strip "bert/"; this blindly drops 5 characters, so "cls/predictions" and "cls/seq_relationship" become the truncated names skipped below
72 | print("Loading {}".format(name))
73 | name = name.split('/')
74 | if name[0] in ['redictions', 'eq_relationship']:
75 | print("Skipping")
76 | continue
77 | pointer = model
78 | for m_name in name:
79 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
80 | l = re.split(r'_(\d+)', m_name)
81 | else:
82 | l = [m_name]
83 | if l[0] == 'kernel':
84 | pointer = getattr(pointer, 'weight')
85 | else:
86 | pointer = getattr(pointer, l[0])
87 | if len(l) >= 2:
88 | num = int(l[1])
89 | pointer = pointer[num]
90 | if m_name[-11:] == '_embeddings':
91 | pointer = getattr(pointer, 'weight')
92 | elif m_name == 'kernel':
93 | array = np.transpose(array)
94 | try:
95 | assert pointer.shape == array.shape
96 | except AssertionError as e:
97 | e.args += (pointer.shape, array.shape)
98 | raise
99 | pointer.data = torch.from_numpy(array)
100 |
101 | # Save pytorch-model
102 | torch.save(model.state_dict(), args.pytorch_dump_path)
103 |
104 | if __name__ == "__main__":
105 | convert()
106 |
--------------------------------------------------------------------------------
/bert/extract_features.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Extract pre-computed feature vectors from a PyTorch BERT model."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import argparse
22 | import codecs
23 | import collections
24 | import logging
25 | import json
26 | import re
27 |
28 | import torch
29 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
30 | from torch.utils.data.distributed import DistributedSampler
31 |
32 | import tokenization
33 | from modeling import BertConfig, BertModel
34 |
35 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
36 | datefmt = '%m/%d/%Y %H:%M:%S',
37 | level = logging.INFO)
38 | logger = logging.getLogger(__name__)
39 |
40 |
41 | class InputExample(object):
42 |
43 | def __init__(self, unique_id, text_a, text_b):
44 | self.unique_id = unique_id
45 | self.text_a = text_a
46 | self.text_b = text_b
47 |
48 |
49 | class InputFeatures(object):
50 | """A single set of features of data."""
51 |
52 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
53 | self.unique_id = unique_id
54 | self.tokens = tokens
55 | self.input_ids = input_ids
56 | self.input_mask = input_mask
57 | self.input_type_ids = input_type_ids
58 |
59 |
60 | def convert_examples_to_features(examples, seq_length, tokenizer):
61 |   """Converts a list of `InputExample`s into a list of `InputFeatures`."""
62 |
63 | features = []
64 | for (ex_index, example) in enumerate(examples):
65 | tokens_a = tokenizer.tokenize(example.text_a)
66 |
67 | tokens_b = None
68 | if example.text_b:
69 | tokens_b = tokenizer.tokenize(example.text_b)
70 |
71 | if tokens_b:
72 | # Modifies `tokens_a` and `tokens_b` in place so that the total
73 | # length is less than the specified length.
74 | # Account for [CLS], [SEP], [SEP] with "- 3"
75 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
76 | else:
77 | # Account for [CLS] and [SEP] with "- 2"
78 | if len(tokens_a) > seq_length - 2:
79 | tokens_a = tokens_a[0:(seq_length - 2)]
80 |
81 | # The convention in BERT is:
82 | # (a) For sequence pairs:
83 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
84 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
85 | # (b) For single sequences:
86 | # tokens: [CLS] the dog is hairy . [SEP]
87 | # type_ids: 0 0 0 0 0 0 0
88 | #
89 | # Where "type_ids" are used to indicate whether this is the first
90 | # sequence or the second sequence. The embedding vectors for `type=0` and
91 | # `type=1` were learned during pre-training and are added to the wordpiece
92 | # embedding vector (and position vector). This is not *strictly* necessary
93 |     # since the [SEP] token unambiguously separates the sequences, but it makes
94 | # it easier for the model to learn the concept of sequences.
95 | #
96 | # For classification tasks, the first vector (corresponding to [CLS]) is
97 |     # used as the "sentence vector". Note that this only makes sense because
98 | # the entire model is fine-tuned.
99 | tokens = []
100 | input_type_ids = []
101 | tokens.append("[CLS]")
102 | input_type_ids.append(0)
103 | for token in tokens_a:
104 | tokens.append(token)
105 | input_type_ids.append(0)
106 | tokens.append("[SEP]")
107 | input_type_ids.append(0)
108 |
109 | if tokens_b:
110 | for token in tokens_b:
111 | tokens.append(token)
112 | input_type_ids.append(1)
113 | tokens.append("[SEP]")
114 | input_type_ids.append(1)
115 |
116 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
117 |
118 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
119 | # tokens are attended to.
120 | input_mask = [1] * len(input_ids)
121 |
122 | # Zero-pad up to the sequence length.
123 | while len(input_ids) < seq_length:
124 | input_ids.append(0)
125 | input_mask.append(0)
126 | input_type_ids.append(0)
127 |
128 | assert len(input_ids) == seq_length
129 | assert len(input_mask) == seq_length
130 | assert len(input_type_ids) == seq_length
131 |
132 | if ex_index < 5:
133 | logger.info("*** Example ***")
134 | logger.info("unique_id: %s" % (example.unique_id))
135 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
136 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
137 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
138 | logger.info(
139 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))
140 |
141 | features.append(
142 | InputFeatures(
143 | unique_id=example.unique_id,
144 | tokens=tokens,
145 | input_ids=input_ids,
146 | input_mask=input_mask,
147 | input_type_ids=input_type_ids))
148 | return features
149 |
150 |
151 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
152 | """Truncates a sequence pair in place to the maximum length."""
153 |
154 | # This is a simple heuristic which will always truncate the longer sequence
155 | # one token at a time. This makes more sense than truncating an equal percent
156 | # of tokens from each, since if one sequence is very short then each token
157 | # that's truncated likely contains more information than a longer sequence.
158 | while True:
159 | total_length = len(tokens_a) + len(tokens_b)
160 | if total_length <= max_length:
161 | break
162 | if len(tokens_a) > len(tokens_b):
163 | tokens_a.pop()
164 | else:
165 | tokens_b.pop()
166 |
167 |
168 | def read_examples(input_file):
169 | """Read a list of `InputExample`s from an input file."""
170 | examples = []
171 | unique_id = 0
172 | with open(input_file, "r") as reader:
173 | while True:
174 | line = tokenization.convert_to_unicode(reader.readline())
175 | if not line:
176 | break
177 | line = line.strip()
178 | text_a = None
179 | text_b = None
180 | m = re.match(r"^(.*) \|\|\| (.*)$", line)
181 | if m is None:
182 | text_a = line
183 | else:
184 | text_a = m.group(1)
185 | text_b = m.group(2)
186 | examples.append(
187 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
188 | unique_id += 1
189 | return examples
190 |
191 |
192 | def main():
193 | parser = argparse.ArgumentParser()
194 |
195 | ## Required parameters
196 | parser.add_argument("--input_file", default=None, type=str, required=True)
197 | parser.add_argument("--vocab_file", default=None, type=str, required=True,
198 | help="The vocabulary file that the BERT model was trained on.")
199 | parser.add_argument("--output_file", default=None, type=str, required=True)
200 | parser.add_argument("--bert_config_file", default=None, type=str, required=True,
201 | help="The config json file corresponding to the pre-trained BERT model. "
202 | "This specifies the model architecture.")
203 | parser.add_argument("--init_checkpoint", default=None, type=str, required=True,
204 | help="Initial checkpoint (usually from a pre-trained BERT model).")
205 |
206 | ## Other parameters
207 | parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
208 | parser.add_argument("--max_seq_length", default=128, type=int,
209 | help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
210 | "than this will be truncated, and sequences shorter than this will be padded.")
211 | parser.add_argument("--do_lower_case", default=True, action='store_true',
212 | help="Whether to lower case the input text. Should be True for uncased "
213 | "models and False for cased models.")
214 | parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
215 | parser.add_argument("--local_rank",
216 | type=int,
217 | default=-1,
218 | help = "local_rank for distributed training on gpus")
219 |
220 | args = parser.parse_args()
221 |
222 |     if args.local_rank == -1:  # note: this script defines no --no_cuda flag
223 |         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
224 | n_gpu = torch.cuda.device_count()
225 | else:
226 | device = torch.device("cuda", args.local_rank)
227 | n_gpu = 1
228 |         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
229 | torch.distributed.init_process_group(backend='nccl')
230 |     logger.info("device: %s n_gpu: %d distributed training: %r", device, n_gpu, bool(args.local_rank != -1))
231 |
232 | layer_indexes = [int(x) for x in args.layers.split(",")]
233 |
234 | bert_config = BertConfig.from_json_file(args.bert_config_file)
235 |
236 | tokenizer = tokenization.FullTokenizer(
237 | vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
238 |
239 | examples = read_examples(args.input_file)
240 |
241 | features = convert_examples_to_features(
242 | examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer)
243 |
244 | unique_id_to_feature = {}
245 | for feature in features:
246 | unique_id_to_feature[feature.unique_id] = feature
247 |
248 | model = BertModel(bert_config)
249 | if args.init_checkpoint is not None:
250 | model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
251 | model.to(device)
252 |
253 | if args.local_rank != -1:
254 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
255 | output_device=args.local_rank)
256 | elif n_gpu > 1:
257 | model = torch.nn.DataParallel(model)
258 |
259 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
260 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
261 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
262 |
263 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)
264 | if args.local_rank == -1:
265 | eval_sampler = SequentialSampler(eval_data)
266 | else:
267 | eval_sampler = DistributedSampler(eval_data)
268 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)
269 |
270 | model.eval()
271 | with open(args.output_file, "w", encoding='utf-8') as writer:
272 | for input_ids, input_mask, example_indices in eval_dataloader:
273 | input_ids = input_ids.to(device)
274 | input_mask = input_mask.to(device)
275 |
276 | all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)
277 |             # all_encoder_layers is a list with one hidden-state tensor per Transformer layer
278 |
279 | for b, example_index in enumerate(example_indices):
280 | feature = features[example_index.item()]
281 | unique_id = int(feature.unique_id)
282 | # feature = unique_id_to_feature[unique_id]
283 | output_json = collections.OrderedDict()
284 | output_json["linex_index"] = unique_id
285 | all_out_features = []
286 | for (i, token) in enumerate(feature.tokens):
287 | all_layers = []
288 | for (j, layer_index) in enumerate(layer_indexes):
289 | layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy()
290 | layer_output = layer_output[b]
291 | layers = collections.OrderedDict()
292 | layers["index"] = layer_index
293 | layers["values"] = [
294 | round(x.item(), 6) for x in layer_output[i]
295 | ]
296 | all_layers.append(layers)
297 | out_features = collections.OrderedDict()
298 | out_features["token"] = token
299 | out_features["layers"] = all_layers
300 | all_out_features.append(out_features)
301 | output_json["features"] = all_out_features
302 | writer.write(json.dumps(output_json) + "\n")
303 |
304 |
305 | if __name__ == "__main__":
306 | main()
307 |
--------------------------------------------------------------------------------
/bert/modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch BERT model."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import copy
22 | import json
23 | import math
24 | import six
25 | import torch
26 | import torch.nn as nn
27 | from torch.nn import CrossEntropyLoss
28 |
29 | def gelu(x):
30 | """Implementation of the gelu activation function.
31 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
32 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
33 | """
34 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
35 |
36 |
37 | class BertConfig(object):
38 | """Configuration class to store the configuration of a `BertModel`.
39 | """
40 | def __init__(self,
41 | vocab_size,
42 | hidden_size=768,
43 | num_hidden_layers=12,
44 | num_attention_heads=12,
45 | intermediate_size=3072,
46 | hidden_act="gelu",
47 | hidden_dropout_prob=0.1,
48 | attention_probs_dropout_prob=0.1,
49 | max_position_embeddings=512,
50 | type_vocab_size=16,
51 | initializer_range=0.02):
52 | """Constructs BertConfig.
53 |
54 | Args:
55 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
56 | hidden_size: Size of the encoder layers and the pooler layer.
57 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
58 | num_attention_heads: Number of attention heads for each attention layer in
59 | the Transformer encoder.
60 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
61 | layer in the Transformer encoder.
62 | hidden_act: The non-linear activation function (function or string) in the
63 | encoder and pooler.
64 |       hidden_dropout_prob: The dropout probability for all fully connected
65 | layers in the embeddings, encoder, and pooler.
66 | attention_probs_dropout_prob: The dropout ratio for the attention
67 | probabilities.
68 | max_position_embeddings: The maximum sequence length that this model might
69 | ever be used with. Typically set this to something large just in case
70 | (e.g., 512 or 1024 or 2048).
71 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
72 | `BertModel`.
73 |       initializer_range: The stdev of the truncated_normal_initializer for
74 | initializing all weight matrices.
75 | """
76 | self.vocab_size = vocab_size
77 | self.hidden_size = hidden_size
78 | self.num_hidden_layers = num_hidden_layers
79 | self.num_attention_heads = num_attention_heads
80 | self.hidden_act = hidden_act
81 | self.intermediate_size = intermediate_size
82 | self.hidden_dropout_prob = hidden_dropout_prob
83 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
84 | self.max_position_embeddings = max_position_embeddings
85 | self.type_vocab_size = type_vocab_size
86 | self.initializer_range = initializer_range
87 |
88 | @classmethod
89 | def from_dict(cls, json_object):
90 | """Constructs a `BertConfig` from a Python dictionary of parameters."""
91 | config = BertConfig(vocab_size=None)
92 | for (key, value) in six.iteritems(json_object):
93 | config.__dict__[key] = value
94 | return config
95 |
96 | @classmethod
97 | def from_json_file(cls, json_file):
98 | """Constructs a `BertConfig` from a json file of parameters."""
99 | with open(json_file, "r") as reader:
100 | text = reader.read()
101 | return cls.from_dict(json.loads(text))
102 |
103 | def to_dict(self):
104 | """Serializes this instance to a Python dictionary."""
105 | output = copy.deepcopy(self.__dict__)
106 | return output
107 |
108 | def to_json_string(self):
109 | """Serializes this instance to a JSON string."""
110 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
111 |
112 |
113 | class BERTLayerNorm(nn.Module):
114 | def __init__(self, config, variance_epsilon=1e-12):
115 | """Construct a layernorm module in the TF style (epsilon inside the square root).
116 | """
117 | super(BERTLayerNorm, self).__init__()
118 | self.gamma = nn.Parameter(torch.ones(config.hidden_size))
119 | self.beta = nn.Parameter(torch.zeros(config.hidden_size))
120 | self.variance_epsilon = variance_epsilon
121 |
122 | def forward(self, x):
123 | u = x.mean(-1, keepdim=True)
124 | s = (x - u).pow(2).mean(-1, keepdim=True)
125 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
126 | return self.gamma * x + self.beta
127 |
128 | class BERTEmbeddings(nn.Module):
129 | def __init__(self, config):
130 |         """Construct the embedding module from word, position and token_type embeddings.
131 |         """
132 |         super(BERTEmbeddings, self).__init__()
133 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
134 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
135 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
136 |
137 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
138 | # any TensorFlow checkpoint file
139 | self.LayerNorm = BERTLayerNorm(config)
140 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
141 |
142 | def forward(self, input_ids, token_type_ids=None):
143 | seq_length = input_ids.size(1)
144 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
145 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
146 | if token_type_ids is None:
147 | token_type_ids = torch.zeros_like(input_ids)
148 |
149 | words_embeddings = self.word_embeddings(input_ids)
150 | position_embeddings = self.position_embeddings(position_ids)
151 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
152 |
153 | embeddings = words_embeddings + position_embeddings + token_type_embeddings
154 | embeddings = self.LayerNorm(embeddings)
155 | embeddings = self.dropout(embeddings)
156 | return embeddings
157 |
158 |
159 | class BERTSelfAttention(nn.Module):
160 | def __init__(self, config):
161 | super(BERTSelfAttention, self).__init__()
162 | if config.hidden_size % config.num_attention_heads != 0:
163 | raise ValueError(
164 | "The hidden size (%d) is not a multiple of the number of attention "
165 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
166 | self.num_attention_heads = config.num_attention_heads
167 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
168 | self.all_head_size = self.num_attention_heads * self.attention_head_size
169 |
170 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
171 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
172 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
173 |
174 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
175 |
176 | def transpose_for_scores(self, x):
177 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
178 | x = x.view(*new_x_shape)
179 | return x.permute(0, 2, 1, 3)
180 |
181 | def forward(self, hidden_states, attention_mask):
182 | mixed_query_layer = self.query(hidden_states)
183 | mixed_key_layer = self.key(hidden_states)
184 | mixed_value_layer = self.value(hidden_states)
185 |
186 | query_layer = self.transpose_for_scores(mixed_query_layer)
187 | key_layer = self.transpose_for_scores(mixed_key_layer)
188 | value_layer = self.transpose_for_scores(mixed_value_layer)
189 |
190 | # Take the dot product between "query" and "key" to get the raw attention scores.
191 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
192 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
193 |         # Apply the attention mask (precomputed for all layers in BertModel's forward() function)
194 | attention_scores = attention_scores + attention_mask
195 |
196 | # Normalize the attention scores to probabilities.
197 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
198 |
199 | # This is actually dropping out entire tokens to attend to, which might
200 | # seem a bit unusual, but is taken from the original Transformer paper.
201 | attention_probs = self.dropout(attention_probs)
202 |
203 | context_layer = torch.matmul(attention_probs, value_layer)
204 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
205 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
206 | context_layer = context_layer.view(*new_context_layer_shape)
207 | return context_layer
208 |
209 |
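For intuition, a self-contained walk-through of the shape handling above, with toy sizes; for brevity the queries, keys and values are all taken from the same tensor, which the real module does not do:

```python
import math
import torch

batch, seq_len, num_heads, head_size = 2, 4, 3, 5
hidden = num_heads * head_size                          # 15

x = torch.randn(batch, seq_len, hidden)
# transpose_for_scores: [batch, seq, hidden] -> [batch, heads, seq, head_size]
x_heads = x.view(batch, seq_len, num_heads, head_size).permute(0, 2, 1, 3)

scores = torch.matmul(x_heads, x_heads.transpose(-1, -2)) / math.sqrt(head_size)
probs = torch.softmax(scores, dim=-1)                   # [batch, heads, seq, seq]
context = torch.matmul(probs, x_heads)                  # [batch, heads, seq, head_size]
context = context.permute(0, 2, 1, 3).reshape(batch, seq_len, hidden)
print(context.shape)                                    # torch.Size([2, 4, 15])
```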
210 | class BERTSelfOutput(nn.Module):
211 | def __init__(self, config):
212 | super(BERTSelfOutput, self).__init__()
213 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
214 | self.LayerNorm = BERTLayerNorm(config)
215 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
216 |
217 | def forward(self, hidden_states, input_tensor):
218 | hidden_states = self.dense(hidden_states)
219 | hidden_states = self.dropout(hidden_states)
220 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
221 | return hidden_states
222 |
223 |
224 | class BERTAttention(nn.Module):
225 | def __init__(self, config):
226 | super(BERTAttention, self).__init__()
227 | self.self = BERTSelfAttention(config)
228 | self.output = BERTSelfOutput(config)
229 |
230 | def forward(self, input_tensor, attention_mask):
231 | self_output = self.self(input_tensor, attention_mask)
232 | attention_output = self.output(self_output, input_tensor)
233 | return attention_output
234 |
235 |
236 | class BERTIntermediate(nn.Module):
237 | def __init__(self, config):
238 | super(BERTIntermediate, self).__init__()
239 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
240 | self.intermediate_act_fn = gelu
241 |
242 | def forward(self, hidden_states):
243 | hidden_states = self.dense(hidden_states)
244 | hidden_states = self.intermediate_act_fn(hidden_states)
245 | return hidden_states
246 |
247 |
248 | class BERTOutput(nn.Module):
249 | def __init__(self, config):
250 | super(BERTOutput, self).__init__()
251 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
252 | self.LayerNorm = BERTLayerNorm(config)
253 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
254 |
255 | def forward(self, hidden_states, input_tensor):
256 | hidden_states = self.dense(hidden_states)
257 | hidden_states = self.dropout(hidden_states)
258 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
259 | return hidden_states
260 |
261 |
262 | class BERTLayer(nn.Module):
263 | def __init__(self, config):
264 | super(BERTLayer, self).__init__()
265 | self.attention = BERTAttention(config)
266 | self.intermediate = BERTIntermediate(config)
267 | self.output = BERTOutput(config)
268 |
269 | def forward(self, hidden_states, attention_mask):
270 | attention_output = self.attention(hidden_states, attention_mask)
271 | intermediate_output = self.intermediate(attention_output)
272 | layer_output = self.output(intermediate_output, attention_output)
273 | return layer_output
274 |
275 |
276 | class BERTEncoder(nn.Module):
277 | def __init__(self, config):
278 | super(BERTEncoder, self).__init__()
279 | layer = BERTLayer(config)
280 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
281 |
282 | def forward(self, hidden_states, attention_mask):
283 | all_encoder_layers = []
284 | for layer_module in self.layer:
285 | hidden_states = layer_module(hidden_states, attention_mask)
286 | all_encoder_layers.append(hidden_states)
287 | return all_encoder_layers
288 |
289 |
290 | class BERTPooler(nn.Module):
291 | def __init__(self, config):
292 | super(BERTPooler, self).__init__()
293 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
294 | self.activation = nn.Tanh()
295 |
296 | def forward(self, hidden_states):
297 | # We "pool" the model by simply taking the hidden state corresponding
298 | # to the first token.
299 | first_token_tensor = hidden_states[:, 0]
300 | pooled_output = self.dense(first_token_tensor)
301 | pooled_output = self.activation(pooled_output)
302 | return pooled_output
303 |
304 |
305 | class BertModel(nn.Module):
306 |     """BERT model ("Bidirectional Encoder Representations from Transformers").
307 |
308 | Example usage:
309 | ```python
310 | # Already been converted into WordPiece token ids
311 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
312 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
313 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
314 |
315 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
316 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
317 |
318 | model = modeling.BertModel(config=config)
319 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
320 | ```
321 | """
322 | def __init__(self, config: BertConfig):
323 | """Constructor for BertModel.
324 |
325 | Args:
326 | config: `BertConfig` instance.
327 | """
328 | super(BertModel, self).__init__()
329 | self.embeddings = BERTEmbeddings(config)
330 | self.encoder = BERTEncoder(config)
331 | self.pooler = BERTPooler(config)
332 |
333 | def forward(self, input_ids, token_type_ids=None, attention_mask=None):
334 | if attention_mask is None:
335 | attention_mask = torch.ones_like(input_ids)
336 | if token_type_ids is None:
337 | token_type_ids = torch.zeros_like(input_ids)
338 |
339 | # We create a 3D attention mask from a 2D tensor mask.
340 | # Sizes are [batch_size, 1, 1, to_seq_length]
341 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
342 |         # This attention mask is simpler than the triangular masking of causal attention
343 |         # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
344 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
345 |
346 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
347 | # masked positions, this operation will create a tensor which is 0.0 for
348 | # positions we want to attend and -10000.0 for masked positions.
349 | # Since we are adding it to the raw scores before the softmax, this is
350 | # effectively the same as removing these entirely.
351 | extended_attention_mask = extended_attention_mask.float()
352 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
353 |
354 | embedding_output = self.embeddings(input_ids, token_type_ids)
355 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
356 | sequence_output = all_encoder_layers[-1]
357 | pooled_output = self.pooler(sequence_output)
358 | return all_encoder_layers, pooled_output
359 |
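A standalone sketch of how the 2D padding mask becomes the additive bias described in the comments above, using the mask from the docstring example:

```python
import torch

attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])      # [batch, seq]
extended = attention_mask.unsqueeze(1).unsqueeze(2).float()     # [batch, 1, 1, seq]
extended = (1.0 - extended) * -10000.0

# The padded position of the second sequence gets a large negative bias,
# which the softmax turns into (near-)zero attention weight.
print(extended[1, 0, 0])   # tensor([-0., -0., -10000.])
```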
360 | class BertForSequenceClassification(nn.Module):
361 | """BERT model for classification.
362 | This module is composed of the BERT model with a linear layer on top of
363 | the pooled output.
364 |
365 | Example usage:
366 | ```python
367 | # Already been converted into WordPiece token ids
368 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
369 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
370 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
371 |
372 | config = BertConfig(vocab_size=32000, hidden_size=512,
373 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
374 |
375 | num_labels = 2
376 |
377 | model = BertForSequenceClassification(config, num_labels)
378 | logits = model(input_ids, token_type_ids, input_mask)
379 | ```
380 | """
381 | def __init__(self, config, num_labels):
382 | super(BertForSequenceClassification, self).__init__()
383 | self.bert = BertModel(config)
384 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
385 | self.classifier = nn.Linear(config.hidden_size, num_labels)
386 |
387 | def init_weights(module):
388 | if isinstance(module, (nn.Linear, nn.Embedding)):
389 | # Slightly different from the TF version which uses truncated_normal for initialization
390 | # cf https://github.com/pytorch/pytorch/pull/5617
391 | module.weight.data.normal_(mean=0.0, std=config.initializer_range)
392 | elif isinstance(module, BERTLayerNorm):
393 | module.beta.data.normal_(mean=0.0, std=config.initializer_range)
394 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
395 | if isinstance(module, nn.Linear):
396 | module.bias.data.zero_()
397 | self.apply(init_weights)
398 |
399 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None, n_class=1):
400 | seq_length = input_ids.size(2)
401 | _, pooled_output = self.bert(input_ids.view(-1,seq_length),
402 | token_type_ids.view(-1,seq_length),
403 | attention_mask.view(-1,seq_length))
404 | pooled_output = self.dropout(pooled_output)
405 | logits = self.classifier(pooled_output)
406 | logits = logits.view(-1, n_class)
407 |
408 | if labels is not None:
409 | loss_fct = CrossEntropyLoss()
410 | labels = labels.view(-1)
411 | loss = loss_fct(logits, labels)
412 | return loss, logits
413 | else:
414 | return logits
415 |
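The multiple-choice reshaping done in `forward` above, sketched with toy shapes; random numbers stand in for the real classifier head:

```python
import torch

batch_size, n_class, seq_length = 2, 4, 6
input_ids = torch.randint(0, 100, (batch_size, n_class, seq_length))

flat_input_ids = input_ids.view(-1, seq_length)          # [batch * n_class, seq_length]
# The classifier head emits one score per (example, choice) pair ...
per_choice_logits = torch.randn(flat_input_ids.size(0), 1)
# ... which is folded back into one row of n_class scores per example.
logits = per_choice_logits.view(-1, n_class)             # [batch, n_class]
print(flat_input_ids.shape, logits.shape)
```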
416 | class BertForQuestionAnswering(nn.Module): #FIXME (kai): not modified accordingly yet
417 | """BERT model for Question Answering (span extraction).
418 | This module is composed of the BERT model with a linear layer on top of
419 | the sequence output that computes start_logits and end_logits
420 |
421 | Example usage:
422 | ```python
423 | # Already been converted into WordPiece token ids
424 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
425 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
426 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
427 |
428 | config = BertConfig(vocab_size=32000, hidden_size=512,
429 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
430 |
431 | model = BertForQuestionAnswering(config)
432 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
433 | ```
434 | """
435 | def __init__(self, config):
436 | super(BertForQuestionAnswering, self).__init__()
437 | self.bert = BertModel(config)
438 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
439 | # self.dropout = nn.Dropout(config.hidden_dropout_prob)
440 | self.qa_outputs = nn.Linear(config.hidden_size, 2)
441 |
442 | def init_weights(module):
443 | if isinstance(module, (nn.Linear, nn.Embedding)):
444 | # Slightly different from the TF version which uses truncated_normal for initialization
445 | # cf https://github.com/pytorch/pytorch/pull/5617
446 | module.weight.data.normal_(mean=0.0, std=config.initializer_range)
447 | elif isinstance(module, BERTLayerNorm):
448 | module.beta.data.normal_(mean=0.0, std=config.initializer_range)
449 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
450 | if isinstance(module, nn.Linear):
451 | module.bias.data.zero_()
452 | self.apply(init_weights)
453 |
454 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None):
455 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
456 | sequence_output = all_encoder_layers[-1]
457 | logits = self.qa_outputs(sequence_output)
458 | start_logits, end_logits = logits.split(1, dim=-1)
459 | start_logits = start_logits.squeeze(-1)
460 | end_logits = end_logits.squeeze(-1)
461 |
462 | if start_positions is not None and end_positions is not None:
463 |             # If we are on multi-GPU, splitting adds an extra dimension; squeeze it away
464 | if len(start_positions.size()) > 1:
465 | start_positions = start_positions.squeeze(-1)
466 | if len(end_positions.size()) > 1:
467 | end_positions = end_positions.squeeze(-1)
468 | # sometimes the start/end positions are outside our model inputs, we ignore these terms
469 | ignored_index = start_logits.size(1)
470 | start_positions.clamp_(0, ignored_index)
471 | end_positions.clamp_(0, ignored_index)
472 |
473 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
474 | start_loss = loss_fct(start_logits, start_positions)
475 | end_loss = loss_fct(end_logits, end_positions)
476 | total_loss = (start_loss + end_loss) / 2
477 | return total_loss
478 | else:
479 | return start_logits, end_logits
480 |
--------------------------------------------------------------------------------
/bert/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.nn.utils import clip_grad_norm_
21 |
22 | def warmup_cosine(x, warmup=0.002):
23 | if x < warmup:
24 | return x/warmup
25 |     return 0.5 * (1.0 + math.cos(math.pi * x))  # math.cos: x is a Python float, not a tensor
26 |
27 | def warmup_constant(x, warmup=0.002):
28 | if x < warmup:
29 | return x/warmup
30 | return 1.0
31 |
32 | def warmup_linear(x, warmup=0.002):
33 | if x < warmup:
34 | return x/warmup
35 | return 1.0 - x
36 |
37 | SCHEDULES = {
38 | 'warmup_cosine':warmup_cosine,
39 | 'warmup_constant':warmup_constant,
40 | 'warmup_linear':warmup_linear,
41 | }
42 |
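The schedules map training progress `x = step / t_total` to a learning-rate multiplier. A quick look at `warmup_linear` with a 10% warmup, assuming this runs in the same module:

```python
for step in range(0, 101, 25):
    x = step / 100                            # fraction of total training steps done
    print(step, round(warmup_linear(x, warmup=0.1), 3))
# 0 0.0    (ramping up during the first 10% of training)
# 25 0.75  (then decaying linearly back to 0)
# 50 0.5
# 75 0.25
# 100 0.0
```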
43 |
44 | class BERTAdam(Optimizer):
45 |     """Implements the BERT version of the Adam algorithm with weight decay fix (and no bias correction).
46 | Params:
47 | lr: learning rate
48 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
49 | t_total: total number of training steps for the learning
50 | rate schedule, -1 means constant learning rate. Default: -1
51 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
52 |         b1: Adam's b1 (first-moment decay rate). Default: 0.9
53 |         b2: Adam's b2 (second-moment decay rate). Default: 0.999
54 |         e: Adam's epsilon. Default: 1e-6
55 | weight_decay_rate: Weight decay. Default: 0.01
56 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
57 | """
58 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear',
59 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01,
60 | max_grad_norm=1.0):
61 | if not lr >= 0.0:
62 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
63 | if schedule not in SCHEDULES:
64 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
65 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
66 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
67 | if not 0.0 <= b1 < 1.0:
68 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
69 | if not 0.0 <= b2 < 1.0:
70 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
71 | if not e >= 0.0:
72 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
73 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
74 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
75 | max_grad_norm=max_grad_norm)
76 | super(BERTAdam, self).__init__(params, defaults)
77 |
78 | def get_lr(self):
79 | lr = []
80 | for group in self.param_groups:
81 | for p in group['params']:
82 | state = self.state[p]
83 | if len(state) == 0:
84 | return [0]
85 | if group['t_total'] != -1:
86 | schedule_fct = SCHEDULES[group['schedule']]
87 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
88 | else:
89 | lr_scheduled = group['lr']
90 | lr.append(lr_scheduled)
91 | return lr
92 |
93 | def step(self, closure=None):
94 | """Performs a single optimization step.
95 |
96 | Arguments:
97 | closure (callable, optional): A closure that reevaluates the model
98 | and returns the loss.
99 | """
100 | loss = None
101 | if closure is not None:
102 | loss = closure()
103 |
104 | for group in self.param_groups:
105 | for p in group['params']:
106 | if p.grad is None:
107 | continue
108 | grad = p.grad.data
109 | if grad.is_sparse:
110 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
111 |
112 | state = self.state[p]
113 |
114 | # State initialization
115 | if len(state) == 0:
116 | state['step'] = 0
117 | # Exponential moving average of gradient values
118 | state['next_m'] = torch.zeros_like(p.data)
119 | # Exponential moving average of squared gradient values
120 | state['next_v'] = torch.zeros_like(p.data)
121 |
122 | next_m, next_v = state['next_m'], state['next_v']
123 | beta1, beta2 = group['b1'], group['b2']
124 |
125 | # Add grad clipping
126 | if group['max_grad_norm'] > 0:
127 | clip_grad_norm_(p, group['max_grad_norm'])
128 |
129 | # Decay the first and second moment running average coefficient
130 | # In-place operations to update the averages at the same time
131 | next_m.mul_(beta1).add_(1 - beta1, grad)
132 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
133 | update = next_m / (next_v.sqrt() + group['e'])
134 |
135 | # Just adding the square of the weights to the loss function is *not*
136 | # the correct way of using L2 regularization/weight decay with Adam,
137 | # since that will interact with the m and v parameters in strange ways.
138 | #
139 |                 # Instead we want to decay the weights in a manner that doesn't interact
140 | # with the m/v parameters. This is equivalent to adding the square
141 | # of the weights to the loss with plain (non-momentum) SGD.
142 | if group['weight_decay_rate'] > 0.0:
143 | update += group['weight_decay_rate'] * p.data
144 |
145 | if group['t_total'] != -1:
146 | schedule_fct = SCHEDULES[group['schedule']]
147 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
148 | else:
149 | lr_scheduled = group['lr']
150 |
151 | update_with_lr = lr_scheduled * update
152 | p.data.add_(-update_with_lr)
153 |
154 | state['step'] += 1
155 |
156 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
157 | # bias_correction1 = 1 - beta1 ** state['step']
158 | # bias_correction2 = 1 - beta2 ** state['step']
159 |
160 | return loss
161 |
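A hedged usage sketch for `BERTAdam` on a toy model (the per-parameter-group weight decay setup used in run_classifier.py is omitted here; the model and data are made up):

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = BERTAdam(model.parameters(), lr=5e-5, warmup=0.1, t_total=100)

x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()        # applies the warmup schedule, weight decay and gradient clipping
model.zero_grad()
```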
--------------------------------------------------------------------------------
/bert/run_classifier.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """BERT finetuning runner."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import csv
22 | import os
23 | import logging
24 | import argparse
25 | import random
26 | from tqdm import tqdm, trange
27 |
28 | import numpy as np
29 | import torch
30 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
31 | from torch.utils.data.distributed import DistributedSampler
32 |
33 | import tokenization
34 | from modeling import BertConfig, BertForSequenceClassification
35 | from optimization import BERTAdam
36 |
37 | import json
38 |
39 | n_class = 4
40 | reverse_order = False
41 | sa_step = False
42 |
43 |
44 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
45 | datefmt = '%m/%d/%Y %H:%M:%S',
46 | level = logging.INFO)
47 | logger = logging.getLogger(__name__)
48 |
49 |
50 | class InputExample(object):
51 | """A single training/test example for simple sequence classification."""
52 |
53 | def __init__(self, guid, text_a, text_b=None, label=None, text_c=None):
54 |         """Constructs an InputExample.
55 |
56 | Args:
57 | guid: Unique id for the example.
58 | text_a: string. The untokenized text of the first sequence. For single
59 | sequence tasks, only this sequence must be specified.
60 | text_b: (Optional) string. The untokenized text of the second sequence.
61 |                 Must be specified only for sequence pair tasks.
62 | label: (Optional) string. The label of the example. This should be
63 | specified for train and dev examples, but not for test examples.
64 | """
65 | self.guid = guid
66 | self.text_a = text_a
67 | self.text_b = text_b
68 | self.text_c = text_c
69 | self.label = label
70 |
71 |
72 | class InputFeatures(object):
73 | """A single set of features of data."""
74 |
75 | def __init__(self, input_ids, input_mask, segment_ids, label_id):
76 | self.input_ids = input_ids
77 | self.input_mask = input_mask
78 | self.segment_ids = segment_ids
79 | self.label_id = label_id
80 |
81 |
82 | class DataProcessor(object):
83 | """Base class for data converters for sequence classification data sets."""
84 |
85 | def get_train_examples(self, data_dir):
86 | """Gets a collection of `InputExample`s for the train set."""
87 | raise NotImplementedError()
88 |
89 | def get_dev_examples(self, data_dir):
90 | """Gets a collection of `InputExample`s for the dev set."""
91 | raise NotImplementedError()
92 |
93 | def get_labels(self):
94 | """Gets the list of labels for this data set."""
95 | raise NotImplementedError()
96 |
97 | @classmethod
98 | def _read_tsv(cls, input_file, quotechar=None):
99 | """Reads a tab separated value file."""
100 | with open(input_file, "r") as f:
101 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
102 | lines = []
103 | for line in reader:
104 | lines.append(line)
105 | return lines
106 |
107 |
108 | class c3Processor(DataProcessor):
109 | def __init__(self):
110 | random.seed(42)
111 | self.D = [[], [], []]
112 |
113 | for sid in range(3):
114 | data = []
115 | for subtask in ["d", "m"]:
116 | with open("data/c3-"+subtask+"-"+["train.json", "dev.json", "test.json"][sid], "r", encoding="utf8") as f:
117 | data += json.load(f)
118 | if sid == 0:
119 | random.shuffle(data)
120 | for i in range(len(data)):
121 | for j in range(len(data[i][1])):
122 | d = ['\n'.join(data[i][0]).lower(), data[i][1][j]["question"].lower()]
123 | for k in range(len(data[i][1][j]["choice"])):
124 | d += [data[i][1][j]["choice"][k].lower()]
125 | for k in range(len(data[i][1][j]["choice"]), 4):
126 | d += ['']
127 | d += [data[i][1][j]["answer"].lower()]
128 | self.D[sid] += [d]
129 |
130 | def get_train_examples(self, data_dir):
131 | """See base class."""
132 | return self._create_examples(
133 | self.D[0], "train")
134 |
135 | def get_test_examples(self, data_dir):
136 | """See base class."""
137 | return self._create_examples(
138 | self.D[2], "test")
139 |
140 | def get_dev_examples(self, data_dir):
141 | """See base class."""
142 | return self._create_examples(
143 | self.D[1], "dev")
144 |
145 | def get_labels(self):
146 | """See base class."""
147 | return ["0", "1", "2", "3"]
148 |
149 | def _create_examples(self, data, set_type):
150 | """Creates examples for the training and dev sets."""
151 | examples = []
152 | for (i, d) in enumerate(data):
153 | for k in range(4):
154 | if data[i][2+k] == data[i][6]:
155 | answer = str(k)
156 |
157 | label = tokenization.convert_to_unicode(answer)
158 |
159 | for k in range(4):
160 | guid = "%s-%s-%s" % (set_type, i, k)
161 | text_a = tokenization.convert_to_unicode(data[i][0])
162 | text_b = tokenization.convert_to_unicode(data[i][k+2])
163 | text_c = tokenization.convert_to_unicode(data[i][1])
164 | examples.append(
165 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, text_c=text_c))
166 |
167 | return examples
168 |
169 |
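A sketch of how `_create_examples` expands one flattened record produced by `c3Processor.__init__`; the record contents below are made up, but the layout `[context, question, choice_0..choice_3, answer]` matches the code above:

```python
d = ["context text", "question text", "choice a", "choice b", "", "", "choice a"]

# The label is the index of the choice that equals the gold answer ("0" here),
# and all four InputExamples of the record share it.
answer = next(str(k) for k in range(4) if d[2 + k] == d[6])
examples = [InputExample(guid="dev-0-%d" % k, text_a=d[0], text_b=d[2 + k],
                         label=answer, text_c=d[1]) for k in range(4)]
print(answer, len(examples))   # 0 4
```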
170 |
171 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
172 |     """Loads a data file into a list of `InputFeatures` groups, one group of n_class features per example."""
173 |
174 | print("#examples", len(examples))
175 |
176 | label_map = {}
177 | for (i, label) in enumerate(label_list):
178 | label_map[label] = i
179 |
180 | features = [[]]
181 | for (ex_index, example) in enumerate(examples):
182 | tokens_a = tokenizer.tokenize(example.text_a)
183 |
184 | tokens_b = tokenizer.tokenize(example.text_b)
185 |
186 | tokens_c = tokenizer.tokenize(example.text_c)
187 |
188 | _truncate_seq_tuple(tokens_a, tokens_b, tokens_c, max_seq_length - 4)
189 | tokens_b = tokens_c + ["[SEP]"] + tokens_b
190 |
191 | tokens = []
192 | segment_ids = []
193 | tokens.append("[CLS]")
194 | segment_ids.append(0)
195 | for token in tokens_a:
196 | tokens.append(token)
197 | segment_ids.append(0)
198 | tokens.append("[SEP]")
199 | segment_ids.append(0)
200 |
201 | if tokens_b:
202 | for token in tokens_b:
203 | tokens.append(token)
204 | segment_ids.append(1)
205 | tokens.append("[SEP]")
206 | segment_ids.append(1)
207 |
208 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
209 |
210 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
211 | # tokens are attended to.
212 | input_mask = [1] * len(input_ids)
213 |
214 | # Zero-pad up to the sequence length.
215 | while len(input_ids) < max_seq_length:
216 | input_ids.append(0)
217 | input_mask.append(0)
218 | segment_ids.append(0)
219 |
220 | assert len(input_ids) == max_seq_length
221 | assert len(input_mask) == max_seq_length
222 | assert len(segment_ids) == max_seq_length
223 |
224 | label_id = label_map[example.label]
225 | if ex_index < 5:
226 | logger.info("*** Example ***")
227 | logger.info("guid: %s" % (example.guid))
228 | logger.info("tokens: %s" % " ".join(
229 | [tokenization.printable_text(x) for x in tokens]))
230 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
231 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
232 | logger.info(
233 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
234 | logger.info("label: %s (id = %d)" % (example.label, label_id))
235 |
236 | features[-1].append(
237 | InputFeatures(
238 | input_ids=input_ids,
239 | input_mask=input_mask,
240 | segment_ids=segment_ids,
241 | label_id=label_id))
242 | if len(features[-1]) == n_class:
243 | features.append([])
244 |
245 | if len(features[-1]) == 0:
246 | features = features[:-1]
247 | print('#features', len(features))
248 | return features
249 |
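The packing built above, illustrated with hypothetical tokens so no tokenizer or vocabulary is needed:

```python
# [CLS] passage [SEP] question [SEP] choice [SEP]
tokens_a = ["passage", "tok"]                                   # text_a = context
tokens_b = ["question", "tok"] + ["[SEP]"] + ["choice", "tok"]  # text_c + [SEP] + text_b

tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
print(tokens)
print(segment_ids)   # 0s for [CLS] + context + [SEP], 1s for the rest
```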
250 |
251 |
252 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
253 | """Truncates a sequence pair in place to the maximum length."""
254 |
255 |     # This is a simple heuristic which always truncates the longer sequence
256 |     # one token at a time. This makes more sense than truncating an equal percent
257 |     # of tokens from each, since if one sequence is very short then each token
258 |     # that's truncated likely carries more information than one from a longer sequence.
259 | while True:
260 | total_length = len(tokens_a) + len(tokens_b)
261 | if total_length <= max_length:
262 | break
263 | if len(tokens_a) > len(tokens_b):
264 | tokens_a.pop()
265 | else:
266 | tokens_b.pop()
267 |
268 |
269 | def _truncate_seq_tuple(tokens_a, tokens_b, tokens_c, max_length):
270 | """Truncates a sequence tuple in place to the maximum length."""
271 |
272 |     # This is a simple heuristic which always truncates the longest sequence
273 |     # one token at a time. This makes more sense than truncating an equal percent
274 |     # of tokens from each, since if one sequence is very short then each token
275 |     # that's truncated likely carries more information than one from a longer sequence.
276 | while True:
277 | total_length = len(tokens_a) + len(tokens_b) + len(tokens_c)
278 | if total_length <= max_length:
279 | break
280 | if len(tokens_a) >= len(tokens_b) and len(tokens_a) >= len(tokens_c):
281 | tokens_a.pop()
282 | elif len(tokens_b) >= len(tokens_a) and len(tokens_b) >= len(tokens_c):
283 | tokens_b.pop()
284 | else:
285 | tokens_c.pop()
286 |
287 |
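A tiny check of the truncation heuristic (runs as-is in this module; the token lists are made up):

```python
# The longest of the three token lists is trimmed one token at a time
# until the combined length fits the budget.
a, b, c = list("aaaaaa"), list("bbb"), list("cc")   # lengths 6, 3, 2
_truncate_seq_tuple(a, b, c, max_length=8)
print(len(a), len(b), len(c))                       # 3 3 2
```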
288 | def accuracy(out, labels):
289 | outputs = np.argmax(out, axis=1)
290 | return np.sum(outputs==labels)
291 |
292 | def main():
293 | parser = argparse.ArgumentParser()
294 |
295 | ## Required parameters
296 | parser.add_argument("--data_dir",
297 | default=None,
298 | type=str,
299 | required=True,
300 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
301 | parser.add_argument("--bert_config_file",
302 | default=None,
303 | type=str,
304 | required=True,
305 | help="The config json file corresponding to the pre-trained BERT model. \n"
306 | "This specifies the model architecture.")
307 | parser.add_argument("--task_name",
308 | default=None,
309 | type=str,
310 | required=True,
311 | help="The name of the task to train.")
312 | parser.add_argument("--vocab_file",
313 | default=None,
314 | type=str,
315 | required=True,
316 | help="The vocabulary file that the BERT model was trained on.")
317 | parser.add_argument("--output_dir",
318 | default=None,
319 | type=str,
320 | required=True,
321 | help="The output directory where the model checkpoints will be written.")
322 |
323 | ## Other parameters
324 | parser.add_argument("--init_checkpoint",
325 | default=None,
326 | type=str,
327 | help="Initial checkpoint (usually from a pre-trained BERT model).")
328 | parser.add_argument("--do_lower_case",
329 | default=False,
330 | action='store_true',
331 | help="Whether to lower case the input text. True for uncased models, False for cased models.")
332 | parser.add_argument("--max_seq_length",
333 | default=128,
334 | type=int,
335 | help="The maximum total input sequence length after WordPiece tokenization. \n"
336 | "Sequences longer than this will be truncated, and sequences shorter \n"
337 | "than this will be padded.")
338 | parser.add_argument("--do_train",
339 | default=False,
340 | action='store_true',
341 | help="Whether to run training.")
342 | parser.add_argument("--do_eval",
343 | default=False,
344 | action='store_true',
345 | help="Whether to run eval on the dev set.")
346 | parser.add_argument("--train_batch_size",
347 | default=32,
348 | type=int,
349 | help="Total batch size for training.")
350 | parser.add_argument("--eval_batch_size",
351 | default=8,
352 | type=int,
353 | help="Total batch size for eval.")
354 | parser.add_argument("--learning_rate",
355 | default=5e-5,
356 | type=float,
357 | help="The initial learning rate for Adam.")
358 | parser.add_argument("--num_train_epochs",
359 | default=3.0,
360 | type=float,
361 | help="Total number of training epochs to perform.")
362 | parser.add_argument("--warmup_proportion",
363 | default=0.1,
364 | type=float,
365 | help="Proportion of training to perform linear learning rate warmup for. "
366 | "E.g., 0.1 = 10%% of training.")
367 | parser.add_argument("--save_checkpoints_steps",
368 | default=1000,
369 | type=int,
370 | help="How often to save the model checkpoint.")
371 | parser.add_argument("--no_cuda",
372 | default=False,
373 | action='store_true',
374 | help="Whether not to use CUDA when available")
375 | parser.add_argument("--local_rank",
376 | type=int,
377 | default=-1,
378 | help="local_rank for distributed training on gpus")
379 | parser.add_argument('--seed',
380 | type=int,
381 | default=42,
382 | help="random seed for initialization")
383 | parser.add_argument('--gradient_accumulation_steps',
384 | type=int,
385 | default=1,
386 |                         help="Number of update steps to accumulate before performing a backward/update pass.")
387 | args = parser.parse_args()
388 |
389 | processors = {
390 | "c3": c3Processor,
391 | }
392 |
393 | if args.local_rank == -1 or args.no_cuda:
394 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
395 | n_gpu = torch.cuda.device_count()
396 | else:
397 | device = torch.device("cuda", args.local_rank)
398 | n_gpu = 1
399 |         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
400 | torch.distributed.init_process_group(backend='nccl')
401 | logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
402 |
403 | if args.gradient_accumulation_steps < 1:
404 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
405 | args.gradient_accumulation_steps))
406 |
407 | args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
408 |
409 | random.seed(args.seed)
410 | np.random.seed(args.seed)
411 | torch.manual_seed(args.seed)
412 | if n_gpu > 0:
413 | torch.cuda.manual_seed_all(args.seed)
414 |
415 | if not args.do_train and not args.do_eval:
416 | raise ValueError("At least one of `do_train` or `do_eval` must be True.")
417 |
418 | bert_config = BertConfig.from_json_file(args.bert_config_file)
419 |
420 | if args.max_seq_length > bert_config.max_position_embeddings:
421 | raise ValueError(
422 | "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format(
423 | args.max_seq_length, bert_config.max_position_embeddings))
424 |
425 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
426 | if args.do_train:
427 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
428 | else:
429 | os.makedirs(args.output_dir, exist_ok=True)
430 |
431 | task_name = args.task_name.lower()
432 |
433 | if task_name not in processors:
434 | raise ValueError("Task not found: %s" % (task_name))
435 |
436 | processor = processors[task_name]()
437 | label_list = processor.get_labels()
438 |
439 | tokenizer = tokenization.FullTokenizer(
440 | vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
441 |
442 | train_examples = None
443 | num_train_steps = None
444 | if args.do_train:
445 | train_examples = processor.get_train_examples(args.data_dir)
446 | num_train_steps = int(
447 | len(train_examples) / n_class / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
448 |
449 | model = BertForSequenceClassification(bert_config, 1 if n_class > 1 else len(label_list))
450 | if args.init_checkpoint is not None:
451 | model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
452 | model.to(device)
453 |
454 | if args.local_rank != -1:
455 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
456 | output_device=args.local_rank)
457 | elif n_gpu > 1:
458 | model = torch.nn.DataParallel(model)
459 |
460 | no_decay = ['bias', 'gamma', 'beta']
461 | optimizer_parameters = [
462 |         {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
463 |         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
464 | ]
465 |
466 | optimizer = BERTAdam(optimizer_parameters,
467 | lr=args.learning_rate,
468 | warmup=args.warmup_proportion,
469 | t_total=num_train_steps)
470 |
471 | global_step = 0
472 |
473 | if args.do_eval:
474 | eval_examples = processor.get_dev_examples(args.data_dir)
475 | eval_features = convert_examples_to_features(
476 | eval_examples, label_list, args.max_seq_length, tokenizer)
477 |
478 | input_ids = []
479 | input_mask = []
480 | segment_ids = []
481 | label_id = []
482 |
483 | for f in eval_features:
484 | input_ids.append([])
485 | input_mask.append([])
486 | segment_ids.append([])
487 | for i in range(n_class):
488 | input_ids[-1].append(f[i].input_ids)
489 | input_mask[-1].append(f[i].input_mask)
490 | segment_ids[-1].append(f[i].segment_ids)
491 | label_id.append([f[0].label_id])
492 |
493 | all_input_ids = torch.tensor(input_ids, dtype=torch.long)
494 | all_input_mask = torch.tensor(input_mask, dtype=torch.long)
495 | all_segment_ids = torch.tensor(segment_ids, dtype=torch.long)
496 | all_label_ids = torch.tensor(label_id, dtype=torch.long)
497 |
498 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
499 | if args.local_rank == -1:
500 | eval_sampler = SequentialSampler(eval_data)
501 | else:
502 | eval_sampler = DistributedSampler(eval_data)
503 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
504 |
505 |
506 | if args.do_train:
507 | best_accuracy = 0
508 |
509 | train_features = convert_examples_to_features(
510 | train_examples, label_list, args.max_seq_length, tokenizer)
511 | logger.info("***** Running training *****")
512 | logger.info(" Num examples = %d", len(train_examples))
513 | logger.info(" Batch size = %d", args.train_batch_size)
514 | logger.info(" Num steps = %d", num_train_steps)
515 |
516 | input_ids = []
517 | input_mask = []
518 | segment_ids = []
519 | label_id = []
520 | for f in train_features:
521 | input_ids.append([])
522 | input_mask.append([])
523 | segment_ids.append([])
524 | for i in range(n_class):
525 | input_ids[-1].append(f[i].input_ids)
526 | input_mask[-1].append(f[i].input_mask)
527 | segment_ids[-1].append(f[i].segment_ids)
528 | label_id.append([f[0].label_id])
529 |
530 | all_input_ids = torch.tensor(input_ids, dtype=torch.long)
531 | all_input_mask = torch.tensor(input_mask, dtype=torch.long)
532 | all_segment_ids = torch.tensor(segment_ids, dtype=torch.long)
533 | all_label_ids = torch.tensor(label_id, dtype=torch.long)
534 |
535 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
536 | if args.local_rank == -1:
537 | train_sampler = RandomSampler(train_data)
538 | else:
539 | train_sampler = DistributedSampler(train_data)
540 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
541 |
542 | for _ in trange(int(args.num_train_epochs), desc="Epoch"):
543 | model.train()
544 | tr_loss = 0
545 | nb_tr_examples, nb_tr_steps = 0, 0
546 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
547 | batch = tuple(t.to(device) for t in batch)
548 | input_ids, input_mask, segment_ids, label_ids = batch
549 | loss, _ = model(input_ids, segment_ids, input_mask, label_ids, n_class)
550 | if n_gpu > 1:
551 | loss = loss.mean() # mean() to average on multi-gpu.
552 | if args.gradient_accumulation_steps > 1:
553 | loss = loss / args.gradient_accumulation_steps
554 | loss.backward()
555 | tr_loss += loss.item()
556 | nb_tr_examples += input_ids.size(0)
557 | nb_tr_steps += 1
558 | if (step + 1) % args.gradient_accumulation_steps == 0:
559 |                     optimizer.step()    # We have accumulated enough gradients
560 | model.zero_grad()
561 | global_step += 1
562 |
563 | model.eval()
564 | eval_loss, eval_accuracy = 0, 0
565 | nb_eval_steps, nb_eval_examples = 0, 0
566 | logits_all = []
567 | for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
568 | input_ids = input_ids.to(device)
569 | input_mask = input_mask.to(device)
570 | segment_ids = segment_ids.to(device)
571 | label_ids = label_ids.to(device)
572 |
573 | with torch.no_grad():
574 | tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids, n_class)
575 |
576 | logits = logits.detach().cpu().numpy()
577 | label_ids = label_ids.to('cpu').numpy()
578 | for i in range(len(logits)):
579 | logits_all += [logits[i]]
580 |
581 | tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))
582 |
583 | eval_loss += tmp_eval_loss.mean().item()
584 | eval_accuracy += tmp_eval_accuracy
585 |
586 | nb_eval_examples += input_ids.size(0)
587 | nb_eval_steps += 1
588 |
589 | eval_loss = eval_loss / nb_eval_steps
590 | eval_accuracy = eval_accuracy / nb_eval_examples
591 |
592 | if args.do_train:
593 | result = {'eval_loss': eval_loss,
594 | 'eval_accuracy': eval_accuracy,
595 | 'global_step': global_step,
596 | 'loss': tr_loss/nb_tr_steps}
597 | else:
598 | result = {'eval_loss': eval_loss,
599 | 'eval_accuracy': eval_accuracy}
600 |
601 | logger.info("***** Eval results *****")
602 | for key in sorted(result.keys()):
603 | logger.info(" %s = %s", key, str(result[key]))
604 |
605 | if eval_accuracy >= best_accuracy:
606 | torch.save(model.state_dict(), os.path.join(args.output_dir, "model_best.pt"))
607 | best_accuracy = eval_accuracy
608 |
609 | model.load_state_dict(torch.load(os.path.join(args.output_dir, "model_best.pt")))
610 | torch.save(model.state_dict(), os.path.join(args.output_dir, "model.pt"))
611 |
612 | model.load_state_dict(torch.load(os.path.join(args.output_dir, "model.pt")))
613 |
614 | if args.do_eval:
615 | logger.info("***** Running evaluation *****")
616 | logger.info(" Num examples = %d", len(eval_examples))
617 | logger.info(" Batch size = %d", args.eval_batch_size)
618 |
619 | model.eval()
620 | eval_loss, eval_accuracy = 0, 0
621 | nb_eval_steps, nb_eval_examples = 0, 0
622 | logits_all = []
623 | for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
624 | input_ids = input_ids.to(device)
625 | input_mask = input_mask.to(device)
626 | segment_ids = segment_ids.to(device)
627 | label_ids = label_ids.to(device)
628 |
629 | with torch.no_grad():
630 | tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids, n_class)
631 |
632 | logits = logits.detach().cpu().numpy()
633 | label_ids = label_ids.to('cpu').numpy()
634 | for i in range(len(logits)):
635 | logits_all += [logits[i]]
636 |
637 | tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))
638 |
639 | eval_loss += tmp_eval_loss.mean().item()
640 | eval_accuracy += tmp_eval_accuracy
641 |
642 | nb_eval_examples += input_ids.size(0)
643 | nb_eval_steps += 1
644 |
645 | eval_loss = eval_loss / nb_eval_steps
646 | eval_accuracy = eval_accuracy / nb_eval_examples
647 |
648 | if args.do_train:
649 | result = {'eval_loss': eval_loss,
650 | 'eval_accuracy': eval_accuracy,
651 | 'global_step': global_step,
652 | 'loss': tr_loss/nb_tr_steps}
653 | else:
654 | result = {'eval_loss': eval_loss,
655 | 'eval_accuracy': eval_accuracy}
656 |
657 |
658 | output_eval_file = os.path.join(args.output_dir, "eval_results_dev.txt")
659 | with open(output_eval_file, "w") as writer:
660 | logger.info("***** Eval results *****")
661 | for key in sorted(result.keys()):
662 | logger.info(" %s = %s", key, str(result[key]))
663 | writer.write("%s = %s\n" % (key, str(result[key])))
664 | output_eval_file = os.path.join(args.output_dir, "logits_dev.txt")
665 | with open(output_eval_file, "w") as f:
666 | for i in range(len(logits_all)):
667 | for j in range(len(logits_all[i])):
668 | f.write(str(logits_all[i][j]))
669 | if j == len(logits_all[i])-1:
670 | f.write("\n")
671 | else:
672 | f.write(" ")
673 |
674 |
675 | eval_examples = processor.get_test_examples(args.data_dir)
676 | eval_features = convert_examples_to_features(
677 | eval_examples, label_list, args.max_seq_length, tokenizer)
678 |
679 | logger.info("***** Running evaluation *****")
680 | logger.info(" Num examples = %d", len(eval_examples))
681 | logger.info(" Batch size = %d", args.eval_batch_size)
682 |
683 | input_ids = []
684 | input_mask = []
685 | segment_ids = []
686 | label_id = []
687 |
688 | for f in eval_features:
689 | input_ids.append([])
690 | input_mask.append([])
691 | segment_ids.append([])
692 | for i in range(n_class):
693 | input_ids[-1].append(f[i].input_ids)
694 | input_mask[-1].append(f[i].input_mask)
695 | segment_ids[-1].append(f[i].segment_ids)
696 | label_id.append([f[0].label_id])
697 |
698 | all_input_ids = torch.tensor(input_ids, dtype=torch.long)
699 | all_input_mask = torch.tensor(input_mask, dtype=torch.long)
700 | all_segment_ids = torch.tensor(segment_ids, dtype=torch.long)
701 | all_label_ids = torch.tensor(label_id, dtype=torch.long)
702 |
703 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
704 | if args.local_rank == -1:
705 | eval_sampler = SequentialSampler(eval_data)
706 | else:
707 | eval_sampler = DistributedSampler(eval_data)
708 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
709 |
710 | model.eval()
711 | eval_loss, eval_accuracy = 0, 0
712 | nb_eval_steps, nb_eval_examples = 0, 0
713 | logits_all = []
714 | for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
715 | input_ids = input_ids.to(device)
716 | input_mask = input_mask.to(device)
717 | segment_ids = segment_ids.to(device)
718 | label_ids = label_ids.to(device)
719 |
720 | with torch.no_grad():
721 | tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids, n_class)
722 |
723 | logits = logits.detach().cpu().numpy()
724 | label_ids = label_ids.to('cpu').numpy()
725 | for i in range(len(logits)):
726 | logits_all += [logits[i]]
727 |
728 | tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))
729 |
730 | eval_loss += tmp_eval_loss.mean().item()
731 | eval_accuracy += tmp_eval_accuracy
732 |
733 | nb_eval_examples += input_ids.size(0)
734 | nb_eval_steps += 1
735 |
736 | eval_loss = eval_loss / nb_eval_steps
737 | eval_accuracy = eval_accuracy / nb_eval_examples
738 |
739 | if args.do_train:
740 | result = {'eval_loss': eval_loss,
741 | 'eval_accuracy': eval_accuracy,
742 | 'global_step': global_step,
743 | 'loss': tr_loss/nb_tr_steps}
744 | else:
745 | result = {'eval_loss': eval_loss,
746 | 'eval_accuracy': eval_accuracy}
747 |
748 |
749 | output_eval_file = os.path.join(args.output_dir, "eval_results_test.txt")
750 | with open(output_eval_file, "w") as writer:
751 | logger.info("***** Eval results *****")
752 | for key in sorted(result.keys()):
753 | logger.info(" %s = %s", key, str(result[key]))
754 | writer.write("%s = %s\n" % (key, str(result[key])))
755 | output_eval_file = os.path.join(args.output_dir, "logits_test.txt")
756 | with open(output_eval_file, "w") as f:
757 | for i in range(len(logits_all)):
758 | for j in range(len(logits_all[i])):
759 | f.write(str(logits_all[i][j]))
760 | if j == len(logits_all[i])-1:
761 | f.write("\n")
762 | else:
763 | f.write(" ")
764 |
765 | if __name__ == "__main__":
766 | main()
767 |
--------------------------------------------------------------------------------
/bert/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 |
26 |
27 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
28 | """Checks whether the casing config is consistent with the checkpoint name."""
29 |
30 | # The casing has to be passed in by the user and there is no explicit check
31 | # as to whether it matches the checkpoint. The casing information probably
32 | # should have been stored in the bert_config.json file, but it's not, so
33 | # we have to heuristically detect it to validate.
34 |
35 | if not init_checkpoint:
36 | return
37 |
38 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
39 | if m is None:
40 | return
41 |
42 | model_name = m.group(1)
43 |
44 | lower_models = [
45 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
46 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
47 | ]
48 |
49 | cased_models = [
50 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
51 | "multi_cased_L-12_H-768_A-12"
52 | ]
53 |
54 | is_bad_config = False
55 | if model_name in lower_models and not do_lower_case:
56 | is_bad_config = True
57 | actual_flag = "False"
58 | case_name = "lowercased"
59 | opposite_flag = "True"
60 |
61 | if model_name in cased_models and do_lower_case:
62 | is_bad_config = True
63 | actual_flag = "True"
64 | case_name = "cased"
65 | opposite_flag = "False"
66 |
67 | if is_bad_config:
68 | raise ValueError(
69 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
70 | "However, `%s` seems to be a %s model, so you "
71 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
72 | "how the model was pre-training. If this error is wrong, please "
73 |         "how the model was pre-trained. If this error is wrong, please "
74 | model_name, case_name, opposite_flag))
75 |
76 |
77 | def convert_to_unicode(text):
78 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
79 | if six.PY3:
80 | if isinstance(text, str):
81 | return text
82 | elif isinstance(text, bytes):
83 | return text.decode("utf-8", "ignore")
84 | else:
85 | raise ValueError("Unsupported string type: %s" % (type(text)))
86 | elif six.PY2:
87 | if isinstance(text, str):
88 | return text.decode("utf-8", "ignore")
89 | elif isinstance(text, unicode):
90 | return text
91 | else:
92 | raise ValueError("Unsupported string type: %s" % (type(text)))
93 | else:
94 | raise ValueError("Not running on Python2 or Python 3?")
95 |
96 |
97 | def printable_text(text):
98 | """Returns text encoded in a way suitable for print or `tf.logging`."""
99 |
100 | # These functions want `str` for both Python2 and Python3, but in one case
101 | # it's a Unicode string and in the other it's a byte string.
102 | if six.PY3:
103 | if isinstance(text, str):
104 | return text
105 | elif isinstance(text, bytes):
106 | return text.decode("utf-8", "ignore")
107 | else:
108 | raise ValueError("Unsupported string type: %s" % (type(text)))
109 | elif six.PY2:
110 | if isinstance(text, str):
111 | return text
112 | elif isinstance(text, unicode):
113 | return text.encode("utf-8")
114 | else:
115 | raise ValueError("Unsupported string type: %s" % (type(text)))
116 | else:
117 | raise ValueError("Not running on Python2 or Python 3?")
118 |
119 |
120 | def load_vocab(vocab_file):
121 | """Loads a vocabulary file into a dictionary."""
122 | vocab = collections.OrderedDict()
123 | index = 0
124 | with open(vocab_file, "r", encoding='utf8') as reader:
125 | while True:
126 | token = convert_to_unicode(reader.readline())
127 | if not token:
128 | break
129 | token = token.strip()
130 | vocab[token] = index
131 | index += 1
132 | return vocab
133 |
134 |
135 | def convert_by_vocab(vocab, items):
136 | """Converts a sequence of [tokens|ids] using the vocab."""
137 | output = []
138 | for item in items:
139 | output.append(vocab[item])
140 | return output
141 |
142 |
143 | def convert_tokens_to_ids(vocab, tokens):
144 | return convert_by_vocab(vocab, tokens)
145 |
146 |
147 | def convert_ids_to_tokens(inv_vocab, ids):
148 | return convert_by_vocab(inv_vocab, ids)
149 |
150 |
151 | def whitespace_tokenize(text):
152 | """Runs basic whitespace cleaning and splitting on a piece of text."""
153 | text = text.strip()
154 | if not text:
155 | return []
156 | tokens = text.split()
157 | return tokens
158 |
159 |
160 | class FullTokenizer(object):
161 |     """Runs end-to-end tokenization."""
162 |
163 | def __init__(self, vocab_file, do_lower_case=True):
164 | self.vocab = load_vocab(vocab_file)
165 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
166 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
167 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
168 |
169 | def tokenize(self, text):
170 | split_tokens = []
171 | for token in self.basic_tokenizer.tokenize(text):
172 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
173 | split_tokens.append(sub_token)
174 |
175 | return split_tokens
176 |
177 | def convert_tokens_to_ids(self, tokens):
178 | return convert_by_vocab(self.vocab, tokens)
179 |
180 | def convert_ids_to_tokens(self, ids):
181 | return convert_by_vocab(self.inv_vocab, ids)
182 |
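An end-to-end sketch of `FullTokenizer` using a made-up toy vocabulary written to a temporary file; the vocabulary contents and the input string are assumptions for illustration only:

```python
import os
import tempfile

vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "un", "##aff", "##able", "你", "好"]
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".txt",
                                 encoding="utf-8") as f:
    f.write("\n".join(vocab_tokens))
    vocab_file = f.name

tokenizer = FullTokenizer(vocab_file=vocab_file, do_lower_case=True)
tokens = tokenizer.tokenize("Unaffable 你好")
print(tokens)                                   # ['un', '##aff', '##able', '你', '好']
print(tokenizer.convert_tokens_to_ids(tokens))  # [3, 4, 5, 6, 7]
os.unlink(vocab_file)
```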
183 |
184 | class BasicTokenizer(object):
185 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
186 |
187 | def __init__(self, do_lower_case=True):
188 | """Constructs a BasicTokenizer.
189 | Args:
190 | do_lower_case: Whether to lower case the input.
191 | """
192 | self.do_lower_case = do_lower_case
193 |
194 | def tokenize(self, text):
195 | """Tokenizes a piece of text."""
196 | text = convert_to_unicode(text)
197 | text = self._clean_text(text)
198 |
199 | # This was added on November 1st, 2018 for the multilingual and Chinese
200 | # models. This is also applied to the English models now, but it doesn't
201 | # matter since the English models were not trained on any Chinese data
202 | # and generally don't have any Chinese data in them (there are Chinese
203 | # characters in the vocabulary because Wikipedia does have some Chinese
204 | # words in the English Wikipedia.).
205 | text = self._tokenize_chinese_chars(text)
206 |
207 | orig_tokens = whitespace_tokenize(text)
208 | split_tokens = []
209 | for token in orig_tokens:
210 | if self.do_lower_case:
211 | token = token.lower()
212 | token = self._run_strip_accents(token)
213 | split_tokens.extend(self._run_split_on_punc(token))
214 |
215 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
216 | return output_tokens
217 |
218 | def _run_strip_accents(self, text):
219 | """Strips accents from a piece of text."""
220 | text = unicodedata.normalize("NFD", text)
221 | output = []
222 | for char in text:
223 | cat = unicodedata.category(char)
224 | if cat == "Mn":
225 | continue
226 | output.append(char)
227 | return "".join(output)
228 |
229 | def _run_split_on_punc(self, text):
230 | """Splits punctuation on a piece of text."""
231 | chars = list(text)
232 | i = 0
233 | start_new_word = True
234 | output = []
235 | while i < len(chars):
236 | char = chars[i]
237 | if _is_punctuation(char):
238 | output.append([char])
239 | start_new_word = True
240 | else:
241 | if start_new_word:
242 | output.append([])
243 | start_new_word = False
244 | output[-1].append(char)
245 | i += 1
246 |
247 | return ["".join(x) for x in output]
248 |
249 | def _tokenize_chinese_chars(self, text):
250 | """Adds whitespace around any CJK character."""
251 | output = []
252 | for char in text:
253 | cp = ord(char)
254 | if self._is_chinese_char(cp):
255 | output.append(" ")
256 | output.append(char)
257 | output.append(" ")
258 | else:
259 | output.append(char)
260 | return "".join(output)
261 |
262 | def _is_chinese_char(self, cp):
263 | """Checks whether CP is the codepoint of a CJK character."""
264 | # This defines a "chinese character" as anything in the CJK Unicode block:
265 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
266 | #
267 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
268 | # despite its name. The modern Korean Hangul alphabet is a different block,
269 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
270 | # space-separated words, so they are not treated specially and handled
271 |         # like all of the other languages.
272 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
273 | (cp >= 0x3400 and cp <= 0x4DBF) or #
274 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
275 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
276 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
277 |         (cp >= 0x2B820 and cp <= 0x2CEAF) or #
278 | (cp >= 0xF900 and cp <= 0xFAFF) or #
279 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
280 | return True
281 |
282 | return False
283 |
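    |   # For example, ord(u"中") == 0x4E2D falls in the first range checked above, so
    |   # _is_chinese_char(0x4E2D) returns True, while ord(u"あ") == 0x3042 (Hiragana)
    |   # and ord(u"A") == 0x41 fall outside every range and return False.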
284 | def _clean_text(self, text):
285 | """Performs invalid character removal and whitespace cleanup on text."""
286 | output = []
287 | for char in text:
288 | cp = ord(char)
289 | if cp == 0 or cp == 0xfffd or _is_control(char):
290 | continue
291 | if _is_whitespace(char):
292 | output.append(" ")
293 | else:
294 | output.append(char)
295 | return "".join(output)
296 |
297 |
298 | class WordpieceTokenizer(object):
299 | """Runs WordPiece tokenziation."""
300 |
301 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
302 | self.vocab = vocab
303 | self.unk_token = unk_token
304 | self.max_input_chars_per_word = max_input_chars_per_word
305 |
306 | def tokenize(self, text):
307 | """Tokenizes a piece of text into its word pieces.
308 | This uses a greedy longest-match-first algorithm to perform tokenization
309 | using the given vocabulary.
310 | For example:
311 | input = "unaffable"
312 | output = ["un", "##aff", "##able"]
313 | Args:
314 |       text: A single token or whitespace-separated tokens. This should have
315 |         already been passed through `BasicTokenizer`.
316 | Returns:
317 | A list of wordpiece tokens.
318 | """
319 |
320 | text = convert_to_unicode(text)
321 |
322 | output_tokens = []
323 | for token in whitespace_tokenize(text):
324 | chars = list(token)
325 | if len(chars) > self.max_input_chars_per_word:
326 | output_tokens.append(self.unk_token)
327 | continue
328 |
329 | is_bad = False
330 | start = 0
331 | sub_tokens = []
332 | while start < len(chars):
333 | end = len(chars)
334 | cur_substr = None
335 | while start < end:
336 | substr = "".join(chars[start:end])
337 | if start > 0:
338 | substr = "##" + substr
339 | if substr in self.vocab:
340 | cur_substr = substr
341 | break
342 | end -= 1
343 | if cur_substr is None:
344 | is_bad = True
345 | break
346 | sub_tokens.append(cur_substr)
347 | start = end
348 |
349 | if is_bad:
350 | output_tokens.append(self.unk_token)
351 | else:
352 | output_tokens.extend(sub_tokens)
353 | return output_tokens
354 |
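    | # Illustrative usage of WordpieceTokenizer (a sketch; assumes the supplied vocab
    | # contains "un", "##aff", and "##able", as the released BERT vocabularies do):
    | #   WordpieceTokenizer(vocab).tokenize(u"unaffable")  # -> ["un", "##aff", "##able"]
    | # A token whose greedy longest-match fails at some position collapses to the
    | # unk_token ("[UNK]").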
355 |
356 | def _is_whitespace(char):
357 | """Checks whether `chars` is a whitespace character."""
358 | # \t, \n, and \r are technically contorl characters but we treat them
359 | # as whitespace since they are generally considered as such.
360 | if char == " " or char == "\t" or char == "\n" or char == "\r":
361 | return True
362 | cat = unicodedata.category(char)
363 | if cat == "Zs":
364 | return True
365 | return False
366 |
367 |
368 | def _is_control(char):
369 | """Checks whether `chars` is a control character."""
370 | # These are technically control characters but we count them as whitespace
371 | # characters.
372 | if char == "\t" or char == "\n" or char == "\r":
373 | return False
374 | cat = unicodedata.category(char)
375 | if cat.startswith("C"):
376 | return True
377 | return False
378 |
379 |
380 | def _is_punctuation(char):
381 | """Checks whether `chars` is a punctuation character."""
382 | cp = ord(char)
383 | # We treat all non-letter/number ASCII as punctuation.
384 | # Characters such as "^", "$", and "`" are not in the Unicode
385 | # Punctuation class but we treat them as punctuation anyways, for
386 | # consistency.
387 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
388 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
389 | return True
390 | cat = unicodedata.category(char)
391 | if cat.startswith("P"):
392 | return True
393 | return False
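    | # For example, _is_punctuation("$") returns True via the ASCII ranges above even
    | # though its Unicode category is "Sc", and _is_punctuation("，") (a fullwidth
    | # comma, category "Po") also returns True.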
394 |
--------------------------------------------------------------------------------
/license.txt:
--------------------------------------------------------------------------------
1 | The C3 dataset is intended for non-commercial research purposes only.
2 |
--------------------------------------------------------------------------------