├── .gitignore
├── LICENSE
├── README.md
├── collect.py
├── data
│   ├── conllpp_dev.txt
│   ├── conllpp_test.txt
│   └── conllpp_train.txt
├── example.sh
├── flair_scripts
│   ├── flair_ner.py
│   └── sequence_tagger_with_weights.py
├── img
│   ├── CrossWeigh.jpg
│   └── CrossWeigh.pdf
├── requirements.txt
└── split.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # General
2 | .DS_Store
3 | .AppleDouble
4 | .LSOverride
5 |
6 | # Icon must end with two \r
7 | Icon
8 |
9 | # Thumbnails
10 | ._*
11 |
12 | # Files that might appear in the root of a volume
13 | .DocumentRevisions-V100
14 | .fseventsd
15 | .Spotlight-V100
16 | .TemporaryItems
17 | .Trashes
18 | .VolumeIcon.icns
19 | .com.apple.timemachine.donotpresent
20 |
21 | # Directories potentially created on remote AFP share
22 | .AppleDB
23 | .AppleDesktop
24 | Network Trash Folder
25 | Temporary Items
26 | .apdisk
27 |
28 | __pycache__/
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 |    Copyright 2021 Zihan Wang
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CrossWeigh
2 |
3 | CrossWeigh: Training Named Entity Tagger from Imperfect Annotations
4 |
5 | ## Motivation
6 |
 7 | Label annotation mistakes made by human annotators bring up two challenges for NER:
 8 | - mistakes in the test set can interfere with the evaluation results and even lead to an inaccurate assessment of model performance.
 9 | - mistakes in the training set can hurt NER model training.
10 |
11 | We address these two problems by:
12 | - manually correcting the mistakes in the test set to form a cleaner benchmark.
13 | - developing the `CrossWeigh` framework to handle the mistakes in the training set.
14 | 
15 |
16 | `CrossWeigh` works with any NER algorithm that accepts weighted training instances. It
17 | is composed of two modules: 1) mistake estimation, where potential mistakes are identified in the training
18 | data through a cross-checking process, and 2) mistake re-weighing, where the weights of those suspected mistakes
19 | are lowered when training the final NER model.
20 |
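For intuition, here is a minimal sketch (not part of the released scripts) of the re-weighing rule implemented in `collect.py`: every training sentence starts with weight 1, and each split/fold in which the cross-checked prediction disagrees with the gold labels multiplies that weight by `eps` (0.7 by default).

```python
# Minimal sketch of the CrossWeigh re-weighing rule (mirrors collect.py).
def sentence_weight(mistake_count: int, eps: float = 0.7) -> float:
    """Weight of a training sentence flagged as a potential mistake `mistake_count` times."""
    return eps ** mistake_count

if __name__ == "__main__":
    # 0 flags -> 1.0, 1 flag -> 0.7, 2 flags -> 0.49, 3 flags -> ~0.34
    for k in range(4):
        print(k, round(sentence_weight(k), 2))
```
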
21 | ## Data
22 | We formally name our corrected dataset as CoNLL++.
23 | `/data/conllpp_test.txt` is the manually corrected test set; there are exactly 186 sentences that
24 | differ from the original test set.
25 | `/data/conllpp_train.txt` and `/data/conllpp_dev.txt` are the original CoNLL03 train and development sets, taken from [Named-Entity-Recognition-NER-Papers](https://github.com/pfliu-nlp/Named-Entity-Recognition-NER-Papers).
26 |
27 | ## Scripts
28 | `split.py` can be used to generate k entity-disjoint folds from a list of datasets (usually both the train and development sets).
29 | `flair_scripts/flair_ner.py` can be used to train a weighted version of Flair.
30 | `collect.py` can be used to collect the predictions on all k held-out folds and produce a weighted training set (an example of the output format is shown below).
31 |
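The weighted training file written by `collect.py` (and read by `flair_scripts/flair_ner.py` when `--include_weight` is passed) is tab-separated, one token per line, with columns token, BIO label, and sentence weight, and a blank line between sentences. A made-up two-sentence excerpt with illustrative weights (tabs rendered as spaces here):

```
EU         B-ORG   1.0
rejects    O       1.0
German     B-MISC  1.0
call       O       1.0
.          O       1.0

Peter      B-PER   0.49
Blackburn  I-PER   0.49
```
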
32 | ## Steps to reproduce
33 | Make sure you are in a Python 3.6+ environment.
34 | See [example.sh](example.sh) to reproduce the results.
35 | Using [Flair](https://github.com/zalandoresearch/flair) (non-pooled version), the final result should be
36 | around 93.19 F1 on the original test set and 94.18 F1 on the corrected test set. Using Flair without CrossWeigh gives
37 | around 92.9 F1.
38 |
39 | ## Results
40 | All results are averaged across 5 runs, and the standard deviation is reported.
41 |
42 | | Model | w/o CrossWeigh (original) | w/ CrossWeigh (original) | w/o CrossWeigh (corrected) | w/ CrossWeigh (corrected) |
43 | | --- | :-------------: | :-------------: | :-------------: | :-------------: |
44 | | [VanillaNER](https://github.com/LiyuanLucasLiu/Vanilla_NER) | 91.44(±0.16) | **91.78**(±**0.06**) | 92.32(±0.16) | **92.64**(±**0.08**) |
45 | | [Flair](https://github.com/zalandoresearch/flair) | 92.87(±**0.08**) | **93.19**(±0.09) | 93.89(±**0.06**) | **94.18**(±**0.06**) |
46 | | [Pooled-Flair](https://github.com/zalandoresearch/flair) | 93.14(±0.14) | **93.43**(±**0.06**) | 94.13(±0.11) | **94.28**(±**0.05**) |
47 | | [GCDT](https://github.com/Adaxry/GCDT) | 93.33(±0.14) | **93.43**(±**0.05**) | 94.58(±0.15) | **94.65**(±**0.06**) |
48 | | [LSTM-CRF](https://github.com/glample/tagger) | 90.64(±0.23) | | 91.47(±0.15) | |
49 | | [LSTM-CNNs-CRF](https://github.com/XuezheMax/NeuroNLP2) | 90.65(±0.57) | | 91.87(±0.50) | |
50 | | [ELMo](https://github.com/allenai/allennlp/blob/master/tutorials/how_to/elmo.md) | 92.28(±0.19) | | 93.42(±0.15) | |
51 |
52 | For all models, we use their suggested parameter settings.
53 | For GCDT, we used the weights estimated from Pooled-Flair for efficiency purposes.
54 |
55 | ## Citation
56 | Please cite the following paper if you find our dataset or framework useful. Thanks!
57 |
58 | >Zihan Wang, Jingbo Shang, Liyuan Liu, Lihao Lu, Jiacheng Liu, and Jiawei Han. "CrossWeigh: Training Named Entity Tagger from Imperfect Annotations." arXiv preprint arXiv:1909.01441 (2019).
59 |
60 | ```
61 | @article{wang2019cross,
62 | title={CrossWeigh: Training Named Entity Tagger from Imperfect Annotations},
63 | author={Wang, Zihan and Shang, Jingbo and Liu, Liyuan and Lu, Lihao and Liu, Jiacheng and Han, Jiawei},
64 | journal={arXiv preprint arXiv:1909.01441},
65 | year={2019}
66 | }
67 | ```
68 |
--------------------------------------------------------------------------------
/collect.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import glob
4 | from collections import defaultdict
5 | from split import load_dataset_from_column
6 |
7 |
 8 | # compares `original_test_filename` with `model_predicted_filename` under all paths, and merges the results
9 | # paths could be like ['splitted_0/fold-0', 'splitted_0/fold-1', ..., 'splitted_1/fold-0'...]
10 | def load_from_splits(paths, original_test_filename, model_predicted_filename):
11 | sentence_potential_mistake_count = defaultdict(int)
12 | for path in paths:
13 | original_test = os.path.join(path, original_test_filename)
14 | model_predicted = os.path.join(path, model_predicted_filename)
15 | assert os.path.exists(original_test)
16 | assert os.path.exists(model_predicted)
17 | original_test = load_dataset_from_column(original_test)
18 | model_predicted = load_dataset_from_column(model_predicted, schema="none") # since there may be invalid label sequences.
19 | for (original_sentence, original_labels), (model_sentence, model_labels) in zip(original_test, model_predicted):
20 | assert ' '.join(original_sentence) == ' '.join(model_sentence)
21 | if ' '.join(original_labels) != ' '.join(model_labels):
22 | sentence_potential_mistake_count[' '.join(original_sentence)] += 1
23 | return sentence_potential_mistake_count
24 |
25 |
26 | def form_weighted_train_set(train_files, train_file_schema, eps, mistake_count):
27 | for train_file in train_files:
28 | assert os.path.exists(train_file)
29 | train_set = []
30 | for train_file in train_files:
31 | train_set.extend(load_dataset_from_column(train_file, schema=train_file_schema))
32 |
33 | weighted_train_set = []
34 | for sentence, labels in train_set:
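        # each split/fold whose cross-checked predictions disagreed with this sentence's gold labels
        # scales its training weight down by another factor of eps (< 1)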
35 | mistakes = mistake_count.get(' '.join(sentence), 0)
36 | weight = eps ** mistakes
37 | weighted_train_set.append([sentence, labels, [weight] * len(labels)])
38 | return weighted_train_set
39 |
40 |
41 | def main(split_folders, train_files, train_file_schema, output_weighted_train_file, model_predicted_filename, eps):
42 | for split_folder in split_folders:
43 | assert os.path.exists(split_folder)
44 | assert not os.path.exists(output_weighted_train_file)
45 | paths = []
46 | for split_folder in split_folders:
47 | paths.extend(glob.glob(os.path.join(split_folder, 'fold-*')))
48 | sentence_potential_mistake_count = load_from_splits(paths, 'test.bio', model_predicted_filename)
49 | weighted_train_set = form_weighted_train_set(train_files, train_file_schema, eps, sentence_potential_mistake_count)
50 | with open(output_weighted_train_file, 'w') as f:
51 | for sentence, labels, weights in weighted_train_set:
52 | for token, label, weight in zip(sentence, labels, weights):
53 | f.write(f'{token}\t{label}\t{weight}\n')
54 | f.write('\n')
55 |
56 |
57 | if __name__ == '__main__':
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument('--split_folders', nargs='+', required=True)
60 | parser.add_argument('--train_files', nargs='+', required=True)
61 | parser.add_argument('--train_file_schema', default="bio", choices=["bio", "iob", "iobes", "none"])
62 | parser.add_argument('--output', required=True)
63 | parser.add_argument('--model_predicted_filename', default='predict.bio')
64 | parser.add_argument('--eps', type=float, default=0.7)
65 | args = parser.parse_args()
66 | print(vars(args))
67 | main(args.split_folders, args.train_files, args.train_file_schema, args.output, args.model_predicted_filename, args.eps)
68 |
--------------------------------------------------------------------------------
/example.sh:
--------------------------------------------------------------------------------
1 | export CONLL03_TRAIN_FILE=data/conllpp_train.txt
2 | export CONLL03_DEV_FILE=data/conllpp_dev.txt
3 | export CONLL03_TEST_FILE=data/conllpp_test.txt
4 | export DATA_FOLDER_PREFIX=splitdata
5 | export MODEL_FOLDER_PREFIX=model
6 | export WEIGHED_MODEL_FOLDER_NAME=weighed
7 | mkdir -p ${DATA_FOLDER_PREFIX}/${WEIGHED_MODEL_FOLDER_NAME}
8 |
9 | # creating splits
10 | for splits in $(seq 1 1 3); do
11 | SPLIT_FOLDER=${DATA_FOLDER_PREFIX}/split-${splits}
12 | python split.py --input_files ${CONLL03_TRAIN_FILE} ${CONLL03_DEV_FILE} \
13 | --output_folder ${SPLIT_FOLDER} \
14 | --schema iob \
15 | --folds 10
16 | done
17 |
18 | # training each split/fold
19 | for splits in $(seq 1 1 3); do
20 | for folds in $(seq 0 1 9); do
21 | FOLD_FOLDER=split-${splits}/fold-${folds}
22 | python flair_scripts/flair_ner.py --folder_name ${FOLD_FOLDER} \
23 | --data_folder_prefix ${DATA_FOLDER_PREFIX} \
24 | --model_folder_prefix ${MODEL_FOLDER_PREFIX}
25 | done
26 | done
27 |
28 | # collecting results and forming a weighted train set.
29 | python collect.py --split_folders ${DATA_FOLDER_PREFIX}/split-* \
30 | --train_files $CONLL03_TRAIN_FILE $CONLL03_DEV_FILE \
31 | --train_file_schema iob \
32 | --output ${DATA_FOLDER_PREFIX}/${WEIGHED_MODEL_FOLDER_NAME}/train.bio
33 |
34 | # train the final model
35 | python flair_scripts/flair_ner.py --folder_name ${WEIGHED_MODEL_FOLDER_NAME} \
36 | --data_folder_prefix ${DATA_FOLDER_PREFIX} \
37 | --model_folder_prefix ${MODEL_FOLDER_PREFIX} \
38 | --include_weight
39 |
--------------------------------------------------------------------------------
/flair_scripts/flair_ner.py:
--------------------------------------------------------------------------------
1 | from flair.data import Corpus
2 | from flair.data_fetcher import NLPTaskDataFetcher
3 | from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings, FlairEmbeddings, PooledFlairEmbeddings
4 | from flair.trainers import ModelTrainer
5 | from flair.models import SequenceTagger
6 | from typing import List
7 | import argparse
8 | import os
9 | from sequence_tagger_with_weights import WeightedSequenceTagger
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--folder_name', required=True)
13 | parser.add_argument('--include_weight', action='store_true')
14 | parser.add_argument('--data_folder_prefix')
15 | parser.add_argument('--model_folder_prefix')
16 | args = parser.parse_args()
17 | print(vars(args))
18 |
19 | column_format = {0: 'text', 1: 'ner'} # the datafiles generated by our scripts have columns: text ner [weight]
20 | if args.include_weight:
21 | column_format[2] = 'weight'
22 |
23 | # this can be modified to suit individual needs.
24 | data_folder = os.path.join(args.data_folder_prefix, args.folder_name)
25 | model_folder = os.path.join(args.model_folder_prefix, args.folder_name)
26 |
27 | if args.include_weight:
28 | model_folder += '_w'
29 | # print(column_format)
30 | corpus: Corpus = NLPTaskDataFetcher.load_column_corpus(data_folder,
31 | column_format=column_format,
32 | tag_to_biloes="ner")
33 |
34 | tag_type = 'ner'
35 |
36 | tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
37 |
38 | embedding_types: List[TokenEmbeddings] = [
39 |
40 | # GloVe embeddings
41 | WordEmbeddings('glove'),
42 |
43 | # contextual string embeddings, forward
44 | FlairEmbeddings('news-forward'),
45 | # PooledFlairEmbeddings('news-forward', pooling='min'),
46 |
47 | # contextual string embeddings, backward
48 | FlairEmbeddings('news-backward'),
49 | # PooledFlairEmbeddings('news-backward', pooling='min'),
50 | ]
51 |
52 | embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)
53 |
54 | tagger: SequenceTagger = WeightedSequenceTagger(hidden_size=256,
55 | embeddings=embeddings,
56 | tag_dictionary=tag_dictionary,
57 | tag_type=tag_type)
58 |
59 | trainer: ModelTrainer = ModelTrainer(tagger, corpus)
60 |
61 | trainer.train(model_folder,
62 | max_epochs=150,
63 | monitor_test=True,
64 | train_with_dev=True)
65 |
66 |
67 | def get_tokens_and_labels(sentence):
68 | tokens = []
69 | labels = []
70 | for token in sentence.tokens:
71 | tokens.append(token.text)
72 | labels.append(token.get_tag("ner").value)
73 | return tokens, labels
74 |
75 |
76 | def iobes2bio(iobes_labels):
77 | bio_labels = []
78 | for label in iobes_labels:
79 | if label[0] == 'S':
80 | bio_labels.append('B' + label[1:])
81 | elif label[0] == 'E':
82 | bio_labels.append('I' + label[1:])
83 | else:
84 | bio_labels.append(label)
85 | return bio_labels
86 |
87 |
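# reload the trained model and write BIO-format predictions on the held-out test split
# (read later by collect.py as predict.bio)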
88 | tagger = SequenceTagger.load(os.path.join(model_folder, 'final-model.pt'))
89 |
90 | test_sentences = [x for x in corpus.test]
91 | tagger.predict(test_sentences)
92 | sentences = []
93 | for sentence in test_sentences:
94 | tokens, labels = get_tokens_and_labels(sentence)
95 | labels = iobes2bio(labels)
96 | sentences.append((tokens, labels))
97 | with open(os.path.join(data_folder, 'predict.bio'), 'w') as f:
98 | for tokens, labels in sentences:
99 | for token, label in zip(tokens, labels):
100 | f.write(f'{token}\t{label}\n')
101 | f.write('\n')
102 |
--------------------------------------------------------------------------------
/flair_scripts/sequence_tagger_with_weights.py:
--------------------------------------------------------------------------------
1 | from flair.models import SequenceTagger
2 | from flair.models.sequence_tagger_model import pad_tensors
3 | import torch
4 | from typing import List
5 | import flair
6 | from flair.data import Sentence
7 |
8 | class WeightedSequenceTagger(SequenceTagger):
9 | def _calculate_loss(
10 | self, features: torch.tensor, sentences: List[Sentence]
11 | ) -> float:
12 |
13 | lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
14 |
15 | tag_list: List = []
16 | weight_list: List[float] = []
17 | for s_id, sentence in enumerate(sentences):
18 | # get the tags in this sentence
19 | tag_idx: List[int] = [
20 | self.tag_dictionary.get_idx_for_item(token.get_tag(self.tag_type).value)
21 | for token in sentence
22 | ]
23 | # add tags as tensor
24 | tag = torch.tensor(tag_idx, device=flair.device)
25 | tag_list.append(tag)
26 |             try:
27 |                 weight = sentence.tokens[0].get_tag("weight").value
28 |                 weight_list.append(float(weight))
29 |             except Exception:  # sentences without a weight tag default to weight 1.0
30 |                 weight_list.append(1.0)
31 |
32 | if self.use_crf:
33 | # pad tags if using batch-CRF decoder
34 | tags, _ = pad_tensors(tag_list)
35 |
36 | forward_score = self._forward_alg(features, lengths)
37 | gold_score = self._score_sentence(features, tags, lengths)
38 |
39 | score = forward_score - gold_score
40 |
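            # scale each sentence's CRF negative log-likelihood by its weight before averaging over the batch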
41 | weight_list = torch.tensor(weight_list, device=flair.device)
42 | score = score * weight_list
43 | return score.mean()
44 |
45 | else:
46 | score = 0
47 | for sentence_feats, sentence_tags, sentence_length in zip(
48 | features, tag_list, lengths
49 | ):
50 | sentence_feats = sentence_feats[:sentence_length]
51 |
52 | score += torch.nn.functional.cross_entropy(
53 | sentence_feats, sentence_tags
54 | )
55 | score /= len(features)
56 | return score
57 |
--------------------------------------------------------------------------------
/img/CrossWeigh.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZihanWangKi/CrossWeigh/77d0c2dbf1a039211426f0fecb11daabba8e7bd1/img/CrossWeigh.jpg
--------------------------------------------------------------------------------
/img/CrossWeigh.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZihanWangKi/CrossWeigh/77d0c2dbf1a039211426f0fecb11daabba8e7bd1/img/CrossWeigh.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.4.0
2 | flair==0.7.0
3 |
--------------------------------------------------------------------------------
/split.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import random
4 | import json
5 | import argparse
6 |
7 |
8 | def validate_bio(labels):
9 | for cur_label, next_label in zip(labels, labels[1:] + ['O']):
10 | if cur_label[0] == 'O':
11 | assert next_label[0] == 'O' or next_label[0] == 'B'
12 | continue
13 | elif cur_label[0] == 'B':
14 | assert next_label[0] == 'O' or next_label[0] == 'B' or (
15 | next_label[0] == 'I' and cur_label[1:] == next_label[1:])
16 | elif cur_label[0] == 'I':
17 | assert next_label[0] == 'O' or next_label[0] == 'B' or \
18 | (next_label[0] == 'I' and cur_label[1:] == next_label[1:])
19 | else:
20 | assert False
21 |
22 |
23 | def iob2bio(iob_labels):
24 | bio_labels = []
25 | for prev_label, cur_label in zip(['O'] + iob_labels[:-1], iob_labels):
26 | if (prev_label[0] == 'O' and cur_label[0] == 'I') or (prev_label[0] != 'O' and
27 | cur_label[0] == 'I' and
28 | prev_label[2:] != cur_label[2:]):
29 | bio_labels.append('B' + cur_label[1:])
30 | else:
31 | bio_labels.append(cur_label)
32 | return bio_labels
33 |
34 | def iobes2bio(iobes_labels):
35 | bio_labels = []
36 | for label in iobes_labels:
37 | if label[0] == 'S':
38 | bio_labels.append('B' + label[1:])
39 | elif label[0] == 'E':
40 | bio_labels.append('I' + label[1:])
41 | else:
42 | bio_labels.append(label)
43 | return bio_labels
44 |
45 |
46 | # loads a column dataset into a list of (tokens, labels)
47 | # labels are converted to BIO (IOB2); `schema` names the input labeling scheme ('none' skips conversion and validation)
48 | def load_dataset_from_column(path, schema='bio'):
49 | with open(path, 'r', encoding='utf-8') as f:
50 | sentences = []
51 | tokens = []
52 | labels = []
53 | for line in f.readlines() + ['']:
54 | if len(line) == 0 or line.startswith('-DOCSTART-') or line.isspace():
55 | if len(tokens) > 0:
56 | if schema is not None and schema != 'none':
57 | if schema == 'iob':
58 | labels = iob2bio(labels)
59 | elif schema == 'iobes':
60 | labels = iobes2bio(labels)
61 | validate_bio(labels)
62 | sentences.append((tokens, labels))
63 | tokens = []
64 | labels = []
65 | else:
66 | splits = line.strip().split()
67 | token, label = splits[0], splits[-1]
68 | tokens.append(token)
69 | labels.append(label)
70 | return sentences
71 |
72 |
73 | # given tokens and labels, extracts the list of entity spans as dicts with 'type', 'start' (inclusive), 'end' (exclusive), and 'surface'
74 | def sent_label_to_entity(tokens, labels):
75 | assert len(tokens) == len(labels)
76 | entities = []
77 | cur_entity = {}
78 | for index, (label, next_label) in enumerate(zip(labels, labels[1:] + ['O'])):
79 | if label[0] == 'B':
80 | cur_entity['type'] = label[2:]
81 | cur_entity['start'] = index
82 | if next_label[0] != 'I' and len(cur_entity) > 0:
83 | cur_entity['end'] = index + 1
84 | cur_entity['surface'] = ' '.join(tokens[cur_entity['start']: cur_entity['end']])
85 | entities.append(cur_entity)
86 | cur_entity = {}
87 | return entities
88 |
89 |
90 | # sentence_entities: [[e1, e2, e3], [e2, e4, e5], [e1, e3], ...]...
91 | # folds: int
92 | # random_seed: int
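# builds `folds` entity-disjoint train/test splits: any training-side sentence that shares an entity
# surface form with the held-out fold is dropped from that fold's training data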
93 | def create_folds(sentence_entities, folds, random_seed):
94 | random.seed(random_seed)
95 | data_size = len(sentence_entities)
96 | indexs = list(range(data_size))
97 | info = {'seed': random_seed, 'folds': folds, 'indexs': indexs}
98 | random.shuffle(indexs)
99 | for i in range(folds):
100 | test_data_indexs = indexs[i::folds]
101 | train_data_indexs = [indexs[x::folds] for x in range(folds) if x != i]
102 | train_data_indexs = [x for y in train_data_indexs for x in y]
103 | forbid_entities = set().union(*[set(sentence_entities[x]) for x in test_data_indexs])
104 | train_data_indexs = list(
105 | filter(lambda x: set(sentence_entities[x]).isdisjoint(forbid_entities), train_data_indexs))
106 | assert set(test_data_indexs).isdisjoint(set(train_data_indexs))
107 | assert set().union(*[set(sentence_entities[x]) for x in test_data_indexs]).isdisjoint(
108 | set().union(*[set(sentence_entities[x]) for x in train_data_indexs]))
109 | _info = {
110 | 'train_indexs': train_data_indexs,
111 | 'test_indexs': test_data_indexs,
112 | 'train_sentences': len(train_data_indexs),
113 | 'train_total_entities': sum(len(sentence_entities[x]) for x in train_data_indexs),
114 | 'train_distinct_entities': len(set().union(*[set(sentence_entities[x]) for x in train_data_indexs])),
115 | 'test_sentences': len(test_data_indexs),
116 | 'test_total_entities': sum(len(sentence_entities[x]) for x in test_data_indexs),
117 | 'test_distinct_entities': len(set().union(*[set(sentence_entities[x]) for x in test_data_indexs])),
118 | }
119 | info[f'fold-{i}'] = _info
120 | print(f"Set {i}")
121 | print(f"Train sentences: {_info['train_sentences']}")
122 | print(f"Train total entities: {_info['train_total_entities']}")
123 | print(f"Train distinct entities: {_info['train_distinct_entities']}")
124 | print(f"Test sentences: {_info['test_sentences']}")
125 | print(f"Test total entities: {_info['test_total_entities']}")
126 | print(f"Test distinct entities: {_info['test_distinct_entities']}")
127 | return info
128 |
129 |
130 | def main(input_files, output_folder, folds, schema):
131 | if os.path.exists(output_folder):
132 | print(f"Output folder {output_folder} exists, exiting...")
133 | sys.exit(1)
134 | os.makedirs(output_folder, exist_ok=True)
135 | for input_file in input_files:
136 | if not os.path.exists(input_file):
137 | print(f"Input file {input_file} does not exist, exiting...")
138 | sys.exit(1)
139 | assert folds > 0
140 |
141 | all_data = []
142 | for input_file in input_files:
143 | all_data.extend(load_dataset_from_column(input_file, schema))
144 |
145 | sentence_entities = [list(map(lambda x: x['surface'], sent_label_to_entity(tokens, labels)))
146 | for tokens, labels in all_data]
147 |
148 | seed = random.randint(111111, 999999)
149 | info = create_folds(sentence_entities, folds, seed)
150 |
151 | for i in range(folds):
152 | train_indexs = info[f'fold-{i}']['train_indexs']
153 | test_indexs = info[f'fold-{i}']['test_indexs']
154 |
155 | os.makedirs(os.path.join(output_folder, f'fold-{i}'), exist_ok=True)
156 |
157 | with open(os.path.join(output_folder, f'fold-{i}', f'train.bio'), 'w') as f:
158 | for x in train_indexs:
159 | for token, label in zip(*all_data[x]):
160 | f.write(f'{token}\t{label}\n')
161 | f.write('\n')
162 | with open(os.path.join(output_folder, f'fold-{i}', f'test.bio'), 'w') as f:
163 | for x in test_indexs:
164 | for token, label in zip(*all_data[x]):
165 | f.write(f'{token}\t{label}\n')
166 | f.write('\n')
167 |
168 | with open(os.path.join(output_folder, 'info.json'), 'w') as f:
169 | json.dump(info, f, indent=2)
170 |
171 |
172 | if __name__ == '__main__':
173 | parser = argparse.ArgumentParser()
174 |     # input files, separated by spaces; they are concatenated together
175 |     parser.add_argument('--input_files', nargs='+', required=True)
176 |     # output folder; a per-fold subfolder is created in it
177 | parser.add_argument('--output_folder', required=True)
178 | # number of folds to make
179 | parser.add_argument('--folds', type=int, default=10)
180 | # label typing schema
181 | parser.add_argument('--schema', default="bio", choices=["bio", "iob", "iobes", "none"])
182 | args = parser.parse_args()
183 | print(vars(args))
184 | main(args.input_files, args.output_folder, args.folds, args.schema)
185 |
--------------------------------------------------------------------------------