├── .gitignore ├── README.md ├── data ├── SplitSentences │ ├── .classpath │ ├── .project │ ├── .settings │ │ ├── org.eclipse.core.resources.prefs │ │ ├── org.eclipse.core.runtime.prefs │ │ └── org.eclipse.jdt.core.prefs │ ├── bin │ │ ├── test │ │ │ ├── Main.class │ │ │ ├── Sentence.class │ │ │ └── SourceReader.class │ │ └── word │ │ │ └── WordCount.class │ ├── readme.md │ └── src │ │ ├── test │ │ ├── Main.java │ │ ├── Sentence.java │ │ └── SourceReader.java │ │ └── word │ │ └── WordCount.java ├── convert.py ├── dev.json ├── dev_filtered.json ├── kbp_sent.txt ├── kbp_vocab.txt ├── kbp_word_count.txt ├── run_info.log ├── test.json ├── test_filtered.json ├── train.json ├── train_filtered.json └── utils.py ├── download_pt_models.sh ├── requirements.txt └── src ├── ablation ├── arg.py ├── chunk_global_encoder.py ├── chunk_global_encoder.sh ├── data.py ├── modeling.py ├── without_global_encoder.py └── without_global_encoder.sh ├── analysis ├── run_analysis.py └── utils.py ├── clustering ├── arg.py ├── cluster.py ├── run_cluster.py ├── run_cluster.sh └── utils.py ├── global_event_coref ├── analysis.py ├── arg.py ├── data.py ├── modeling.py ├── run_global_base.py ├── run_global_base.sh ├── run_global_base_with_mask.py ├── run_global_base_with_mask.sh ├── run_global_base_with_mask_topic.py ├── run_global_base_with_mask_topic.sh ├── run_global_base_with_topic.py └── run_global_base_with_topic.sh ├── joint_model ├── arg.py ├── data.py ├── modeling.py ├── run_joint_base.py └── run_joint_base.sh ├── local_event_coref ├── arg.py ├── data.py ├── modeling.py ├── run_local_base.py ├── run_local_base.sh ├── run_local_base_with_mask.py ├── run_local_base_with_mask.sh ├── run_local_base_with_mask_topic.py ├── run_local_base_with_mask_topic.sh ├── run_local_base_with_topic.py └── run_local_base_with_topic.sh ├── tools.py └── trigger_detection ├── arg.py ├── data.py ├── modeling.py ├── run_td_crf.py ├── run_td_crf.sh ├── run_td_softmax.py └── run_td_softmax.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | cache/ 7 | results/ 8 | reference-coreference-scorers/ 9 | 10 | data/LDC_TAC_KBP/ 11 | 12 | # VS CODE 13 | .vscode/ 14 | 15 | # MACOS 16 | .DS_Store 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | pip-wheel-metadata/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improving Event Coreference Resolution Using Document-level and Topic-level Information 2 | 3 | This code was used in the paper: 4 | 5 | **"[Improving Event Coreference Resolution Using Document-level and Topic-level Information](https://aclanthology.org/2022.emnlp-main.454/)"** 6 | Sheng Xu, Peifeng Li and Qiaoming Zhu. EMNLP 2022. 7 | 8 | A simple pipeline model implemented in PyTorch for resolving within-document event coreference. The model was trained and evaluated on the KBP corpus. 9 | 10 | ## Set up 11 | 12 | #### Requirements 13 | 14 | Set up a Python virtual environment and run: 15 | 16 | ```bash 17 | python3 -m pip install -r requirements.txt 18 | ``` 19 | 20 | #### Download the evaluation script 21 | 22 | Coreference results are obtained using official [**Reference Coreference Scorer**](https://github.com/conll/reference-coreference-scorers). 
This scorer reports results in terms of AVG-F, which is the unweighted average of the F-scores of four commonly used coreference evaluation metrics: $\text{MUC}$ ([Vilain et al., 1995](https://www.aclweb.org/anthology/M95-1005/)), $B^3$ ([Bagga and Baldwin, 1998](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.47.5848&rep=rep1&type=pdf)), $\text{CEAF}_e$ ([Luo, 2005](https://www.aclweb.org/anthology/H05-1004/)) and $\text{BLANC}$ ([Recasens and Hovy, 2011](https://www.researchgate.net/profile/Eduard-Hovy/publication/231881781_BLANC_Implementing_the_Rand_index_for_coreference_evaluation/links/553122420cf2f2a588acdc95/BLANC-Implementing-the-Rand-index-for-coreference-evaluation.pdf)).

Run (from the repo root):

```bash
git clone git@github.com:conll/reference-coreference-scorers.git
```

#### Download pretrained models

Download the pretrained model weights (e.g. `bert-base-cased`) from the Huggingface [Model Hub](https://huggingface.co/models):

```bash
bash download_pt_models.sh
```

**Note:** this script downloads all pretrained models used in our experiments into `../PT_MODELS/`.

#### Prepare the dataset

This repo assumes access to the English corpora used in the TAC KBP Event Nugget Detection and Coreference tasks (i.e., [KBP 2015](http://cairo.lti.cs.cmu.edu/kbp/2015/event/), [KBP 2016](http://cairo.lti.cs.cmu.edu/kbp/2016/event/), and [KBP 2017](http://cairo.lti.cs.cmu.edu/kbp/2017/event/)). In total, they contain 648 + 169 + 167 = 984 documents, which are either newswire articles or discussion forum threads. The expected source directories are:

```
'2015': [
    'LDC_TAC_KBP/LDC2015E29/data/',
    'LDC_TAC_KBP/LDC2015E68/data/',
    'LDC_TAC_KBP/LDC2017E02/data/2015/training/',
    'LDC_TAC_KBP/LDC2017E02/data/2015/eval/'
],
'2016': [
    'LDC_TAC_KBP/LDC2017E02/data/2016/eval/eng/nw/',
    'LDC_TAC_KBP/LDC2017E02/data/2016/eval/eng/df/'
],
'2017': [
    'LDC_TAC_KBP/LDC2017E54/data/eng/nw/',
    'LDC_TAC_KBP/LDC2017E54/data/eng/df/'
]
```

|                  | KBP 2015 | KBP 2016 | KBP 2017 |  All  |
| ---------------- | :------: | :------: | :------: | :---: |
| \#Documents      |   648    |   169    |   167    |  984  |
| \#Event mentions |  18739   |   4155   |   4375   | 27269 |
| \#Event Clusters |  11603   |   3191   |   2963   | 17757 |

Following ([Lu & Ng, 2021](https://aclanthology.org/2021.emnlp-main.103/)), we select LDC2015E29, E68, E73, E94 and LDC2016E64 as the training set (817 docs: 735 for training and the remaining 82 for parameter tuning), and report results on the KBP 2017 dataset.

**Dataset Statistics:**

|                  | Train |  Dev  | Test  |  All  |
| ---------------- | :---: | :---: | :---: | :---: |
| \#Documents      |  735  |  82   |  167  |  984  |
| \#Event mentions | 20512 | 2382  | 4375  | 27269 |
| \#Event Clusters | 13292 | 1502  | 2963  | 17757 |

Then,

1. Split sentences and count verbs/entities in the documents using Stanford CoreNLP (see the [readme](data/SplitSentences/readme.md)), creating `kbp_sent.txt` and `kbp_word_count.txt` in the *data* folder.

2. Convert the original dataset into JSON Lines format using:

   ```bash
   cd data/

   export DATA_DIR=<path to the LDC_TAC_KBP folder>
   python3 convert.py --kbp_data_dir $DATA_DIR
   ```

**Note:** this script creates `train.json`, `dev.json` and `test.json` in the *data* folder, as well as `train_filtered.json`, `dev_filtered.json` and `test_filtered.json`, in which event mentions at the same position and overlapping event mentions are filtered out.
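Each line of the generated files is one document serialized as a JSON object with `doc_id`, `document`, `sentences`, `events` and `clusters` fields (see `data/convert.py` for how they are built). A minimal loader sketch (the file name is just an example):

```python
import json

def load_docs(path):
    """Load a JSON Lines file produced by convert.py (one document per line)."""
    with open(path, 'rt', encoding='utf-8') as f:
        return [json.loads(line) for line in f]

docs = load_docs('test_filtered.json')
for doc in docs[:3]:
    # each entry of 'events' is a mention dict; each entry of 'clusters'
    # groups coreferent event_ids into a hopper
    print(doc['doc_id'], '-', len(doc['events']), 'mentions in', len(doc['clusters']), 'clusters')
```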
## Training

#### Trigger Detection

Train a sequence labeling model for Trigger Detection using the BIO tagging scheme (run with `--do_train`):

```bash
cd src/trigger_detection/

export OUTPUT_DIR=./softmax_ce_results/

python3 run_td_softmax.py \
    --output_dir=$OUTPUT_DIR \
    --model_type=longformer \
    --model_checkpoint=../../PT_MODELS/allenai/longformer-large-4096/ \
    --train_file=../../data/train_filtered.json \
    --dev_file=../../data/dev_filtered.json \
    --test_file=../../data/test_filtered.json \
    --max_seq_length=4096 \
    --learning_rate=1e-5 \
    --softmax_loss=ce \
    --num_train_epochs=50 \
    --batch_size=1 \
    --do_train \
    --warmup_proportion=0. \
    --seed=42
```

After training, the model weights and the evaluation results on the **Dev** set will be saved in `$OUTPUT_DIR`.
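Under the BIO scheme, each token receives `O`, or a `B-`/`I-` tag marking the beginning/inside of a trigger span, typed with the event subtype. A toy illustration (the exact label set is built from the KBP subtypes in the training data, so the tag shown here is only an example):

```python
# Illustrative BIO labels for a sentence with a two-token trigger "shot down";
# "Conflict.Attack" is a KBP event subtype, used here purely as an example.
tokens = ["Rebels", "shot", "down", "the", "plane", "."]
labels = ["O", "B-Conflict.Attack", "I-Conflict.Attack", "O", "O", "O"]
```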
#### Event Coreference

Train the full version of our event coreference model (run with `--do_train`):

```bash
cd src/global_event_coref/

export OUTPUT_DIR=./MaskTopic_M-multi-cosine_results/

python3 run_global_base_with_mask_topic.py \
    --output_dir=$OUTPUT_DIR \
    --model_type=longformer \
    --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \
    --mention_encoder_type=bert \
    --mention_encoder_checkpoint=../../../PT_MODELS/bert-base-cased/ \
    --topic_model=vmf \
    --topic_dim=32 \
    --topic_inter_map=64 \
    --train_file=../../data/train_filtered.json \
    --dev_file=../../data/dev_filtered.json \
    --test_file=../../data/test_filtered.json \
    --max_seq_length=4096 \
    --max_mention_length=256 \
    --learning_rate=1e-5 \
    --matching_style=multi_cosine \
    --softmax_loss=ce \
    --num_train_epochs=50 \
    --batch_size=1 \
    --do_train \
    --warmup_proportion=0. \
    --seed=42
```

After training, the model weights and the evaluation results on the **Dev** set will be saved in `$OUTPUT_DIR`.

## Evaluation

#### Trigger Detection

Run *run_td_softmax.py* with `--do_test`:

```bash
cd src/trigger_detection/

export OUTPUT_DIR=./softmax_ce_results/

python3 run_td_softmax.py \
    --output_dir=$OUTPUT_DIR \
    --model_type=longformer \
    --model_checkpoint=../../PT_MODELS/allenai/longformer-large-4096/ \
    --train_file=../../data/train_filtered.json \
    --dev_file=../../data/dev_filtered.json \
    --test_file=../../data/test_filtered.json \
    --max_seq_length=4096 \
    --learning_rate=1e-5 \
    --softmax_loss=ce \
    --num_train_epochs=50 \
    --batch_size=1 \
    --do_test \
    --warmup_proportion=0. \
    --seed=42
```

After evaluation, the results on the **Test** set will be saved in `$OUTPUT_DIR`. Use the `--do_predict` flag to predict subtype labels; the predictions, i.e., `XXX_test_pred_events.json`, will also be saved in `$OUTPUT_DIR`.

#### Event Coreference

Run *run_global_base_with_mask_topic.py* with `--do_test`:

```bash
cd src/global_event_coref/

export OUTPUT_DIR=./MaskTopic_M-multi-cosine_results/

python3 run_global_base_with_mask_topic.py \
    --output_dir=$OUTPUT_DIR \
    --model_type=longformer \
    --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \
    --mention_encoder_type=bert \
    --mention_encoder_checkpoint=../../../PT_MODELS/bert-base-cased/ \
    --topic_model=vmf \
    --topic_dim=32 \
    --topic_inter_map=64 \
    --train_file=../../data/train_filtered.json \
    --dev_file=../../data/dev_filtered.json \
    --test_file=../../data/test_filtered.json \
    --max_seq_length=4096 \
    --max_mention_length=256 \
    --learning_rate=1e-5 \
    --matching_style=multi_cosine \
    --softmax_loss=ce \
    --num_train_epochs=50 \
    --batch_size=1 \
    --do_test \
    --warmup_proportion=0. \
    --seed=42
```

After evaluation, the results on the **Test** set will be saved in `$OUTPUT_DIR`. Use the `--do_predict` flag to predict coreference for event mention pairs; the predictions, i.e., `XXX_test_pred_corefs.json`, will be saved in `$OUTPUT_DIR`.

#### Clustering

Create the final event clusters from the predicted pairwise results:

```bash
cd src/clustering

export OUTPUT_DIR=./TEMP/

python3 run_cluster.py \
    --output_dir=$OUTPUT_DIR \
    --test_golden_filepath=../../data/test.json \
    --test_pred_filepath=../../data/XXX_weights.bin_test_pred_corefs.json \
    --golden_conll_filename=gold_test.conll \
    --pred_conll_filename=pred_test.conll \
    --do_evaluate
```
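To score the generated gold/predicted CoNLL files with the Reference Coreference Scorer yourself, the scorer's standard interface is `perl scorer.pl <metric> <key_file> <response_file>`. A minimal sketch, run from the repo root and assuming the CoNLL files end up in the clustering `$OUTPUT_DIR` above (that location is an assumption):

```bash
# Score predicted clusters against gold with the four metrics behind AVG-F;
# AVG-F is the unweighted mean of the four resulting F-scores.
for metric in muc bcub ceafe blanc; do
    perl reference-coreference-scorers/scorer.pl $metric \
        src/clustering/TEMP/gold_test.conll \
        src/clustering/TEMP/pred_test.conll
done
```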
## Results

#### Download Final Models

You can download the final Trigger Detection and Event Coreference models at:

[https://drive.google.com/drive/folders/182jll9UZ8yqQ93Dev92XDI0v2jhN7wcw?usp=sharing](https://drive.google.com/drive/folders/182jll9UZ8yqQ93Dev92XDI0v2jhN7wcw?usp=sharing)

#### Trigger Detection

| Model                                                            | Micro (P / R / F1) | Macro (P / R / F1) |
| ---------------------------------------------------------------- | :----------------: | :----------------: |
| [(Lu & Ng, 2021)](https://aclanthology.org/2021.emnlp-main.103/) | 71.6 / 58.7 / 64.5 |     - / - / -      |
| Longformer                                                       | 63.0 / 58.1 / 60.4 | 65.2 / 57.7 / 59.2 |
| Longformer+CRF                                                   | 64.8 / 54.6 / 59.2 | 65.9 / 55.2 / 58.1 |

#### Classical Pairwise Models

| Model                       | Pairwise (P / R / F1) | MUC  | $B^3$ | $\text{CEAF}_e$ | BLANC | AVG-F |
| --------------------------- | :-------------------: | :--: | :---: | :-------------: | :---: | :---: |
| BERT-large[Prod]            |  62.3 / 49.3 / 55.0   | 36.5 | 54.4  |      55.8       | 37.3  | 46.0  |
| RoBERTa-large[Prod]         |  64.6 / 44.0 / 52.4   | 36.0 | 54.8  |      55.6       | 37.3  | 45.9  |
| BERT-large[Prod] + Local    |  69.0 / 45.5 / 54.8   | 37.6 | 55.1  |      57.1       | 38.5  | 47.1  |
| RoBERTa-large[Prod] + Local |  71.7 / 49.9 / 58.9   | 39.0 | 55.8  |      58.0       | 39.6  | 48.1  |

#### Pairwise & Chunk Variants

These variants replace the Global Mention Encoder in our model with a pairwise (sentence-level) encoder or a chunk (segment-level) encoder.

| Model                  | Pairwise (P / R / F1) | MUC  | $B^3$ | $\text{CEAF}_e$ | BLANC | AVG-F |
| ---------------------- | :-------------------: | :--: | :---: | :-------------: | :---: | :---: |
| BERT-base[Pairwise]    |  64.0 / 39.8 / 49.0   | 35.3 | 54.4  |      55.8       | 36.6  | 45.5  |
| RoBERTa-base[Pairwise] |  59.9 / 55.6 / 57.7   | 39.0 | 54.3  |      56.4       | 38.6  | 47.1  |
| BERT-base[Chunk]       |  59.7 / 50.6 / 54.7   | 38.4 | 54.9  |      55.4       | 37.9  | 46.7  |
| RoBERTa-base[Chunk]    |  64.0 / 51.3 / 56.9   | 39.6 | 55.2  |      56.9       | 38.5  | 47.6  |

#### Our Model

| Model                                                            | Pairwise (P / R / F1) | MUC  | $B^3$ | $\text{CEAF}_e$ | BLANC | AVG-F |
| ---------------------------------------------------------------- | :-------------------: | :--: | :---: | :-------------: | :---: | :---: |
| [(Lu & Ng, 2021)](https://aclanthology.org/2021.emnlp-main.103/) |           -           | 45.2 | 54.7  |      53.8       | 38.2  | 48.0  |
| Global                                                           |  74.7 / 63.2 / 68.4   | 45.4 | 57.3  |      58.7       | 42.2  | 50.9  |
| + Local                                                          |  72.4 / 63.3 / 67.6   | 45.8 | 57.5  |      59.1       | 42.1  | 51.1  |
| + Local & Topic                                                  |  72.0 / 64.4 / 68.0   | 46.2 | 57.4  |      59.0       | 42.0  | 51.2  |

#### Variants using different tensor matching

| Model              | Pairwise (P / R / F1) | MUC  | $B^3$ | $\text{CEAF}_e$ | BLANC | AVG-F |
| ------------------ | :-------------------: | :--: | :---: | :-------------: | :---: | :---: |
| Base               |  37.5 / 48.0 / 42.1   | 36.7 | 54.9  |      55.3       | 34.7  | 45.4  |
| Base+Prod          |  71.2 / 64.0 / 67.4   | 45.4 | 57.0  |      58.6       | 41.2  | 50.5  |
| Base+Prod+Cos      |  72.0 / 64.4 / 68.0   | 46.2 | 57.4  |      59.0       | 42.0  | 51.2  |
| Base+Prod+Diff     |  70.3 / 67.1 / 68.7   | 45.0 | 56.7  |      58.9       | 41.4  | 50.5  |
| Base+Prod+Diff+Cos |  69.5 / 65.9 / 67.6   | 44.4 | 56.5  |      58.6       | 41.2  | 50.2  |

## Contact info

Contact [Sheng Xu](https://github.com/jsksxs360) at *[sxu@stu.suda.edu.cn](mailto:sxu@stu.suda.edu.cn)* for questions about this repository.

If you use this code, please cite:
299 | 300 | ``` 301 | @inproceedings{xu-etal-2022-improving, 302 | title = "Improving Event Coreference Resolution Using Document-level and Topic-level Information", 303 | author = "Xu, Sheng and 304 | Li, Peifeng and 305 | Zhu, Qiaoming", 306 | booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing", 307 | month = dec, 308 | year = "2022", 309 | address = "Abu Dhabi, United Arab Emirates", 310 | publisher = "Association for Computational Linguistics", 311 | url = "https://aclanthology.org/2022.emnlp-main.454", 312 | pages = "6765--6775" 313 | } 314 | ``` 315 | -------------------------------------------------------------------------------- /data/SplitSentences/.classpath: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /data/SplitSentences/.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | SplitSentences 4 | 5 | 6 | 7 | 8 | 9 | org.eclipse.jdt.core.javabuilder 10 | 11 | 12 | 13 | 14 | 15 | org.eclipse.jdt.core.javanature 16 | 17 | 18 | -------------------------------------------------------------------------------- /data/SplitSentences/.settings/org.eclipse.core.resources.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | encoding/=UTF-8 3 | -------------------------------------------------------------------------------- /data/SplitSentences/.settings/org.eclipse.core.runtime.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | line.separator=\n 3 | -------------------------------------------------------------------------------- /data/SplitSentences/.settings/org.eclipse.jdt.core.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled 3 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8 4 | org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve 5 | org.eclipse.jdt.core.compiler.compliance=1.8 6 | org.eclipse.jdt.core.compiler.debug.lineNumber=generate 7 | org.eclipse.jdt.core.compiler.debug.localVariable=generate 8 | org.eclipse.jdt.core.compiler.debug.sourceFile=generate 9 | org.eclipse.jdt.core.compiler.problem.assertIdentifier=error 10 | org.eclipse.jdt.core.compiler.problem.enumIdentifier=error 11 | org.eclipse.jdt.core.compiler.source=1.8 12 | -------------------------------------------------------------------------------- /data/SplitSentences/bin/test/Main.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsksxs360/event-coref-emnlp2022/f6b82beee0558b187e63f50cf3064fc9ef9d39c8/data/SplitSentences/bin/test/Main.class -------------------------------------------------------------------------------- /data/SplitSentences/bin/test/Sentence.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsksxs360/event-coref-emnlp2022/f6b82beee0558b187e63f50cf3064fc9ef9d39c8/data/SplitSentences/bin/test/Sentence.class -------------------------------------------------------------------------------- /data/SplitSentences/bin/test/SourceReader.class: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsksxs360/event-coref-emnlp2022/f6b82beee0558b187e63f50cf3064fc9ef9d39c8/data/SplitSentences/bin/test/SourceReader.class -------------------------------------------------------------------------------- /data/SplitSentences/bin/word/WordCount.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsksxs360/event-coref-emnlp2022/f6b82beee0558b187e63f50cf3064fc9ef9d39c8/data/SplitSentences/bin/word/WordCount.class -------------------------------------------------------------------------------- /data/SplitSentences/readme.md: -------------------------------------------------------------------------------- 1 | ### Split sentences and Count verb/entity numbers 2 | 3 | 1. Create *SplitSentences/lib* folder if not exist. 4 | 2. Download **CoreNLP X.X.X** and **English (KBP)** model jar from [CoreNLP](https://stanfordnlp.github.io/CoreNLP/index.html#quickstart). 5 | Unzip **CoreNLP X.X.X**, move `slf4j-api.jar`, `slf4j-simple.jar`, `stanford-corenlp-x.x.x.jar`, `stanford-corenlp-x.x.x-models.jar`, and ``stanford-corenlp-models-english-kbp.jar`` to the *SplitSentences/lib* folder. 6 | 3. Download [**Gson**](http://www.java2s.com/example/jar/g/gson-index.html) jar and move `gson-x.x.x.jar` to the *SplitSentences/lib* folder. 7 | 4. Run `Main.java` and `WordCount.java`. 8 | 9 | -------------------------------------------------------------------------------- /data/SplitSentences/src/test/Main.java: -------------------------------------------------------------------------------- 1 | package test; 2 | 3 | import java.io.BufferedWriter; 4 | import java.io.FileWriter; 5 | import java.io.IOException; 6 | import java.util.Arrays; 7 | import java.util.LinkedList; 8 | import java.util.List; 9 | import java.util.regex.Matcher; 10 | import java.util.regex.Pattern; 11 | 12 | public class Main { 13 | 14 | public static void main(String[] args) { 15 | 16 | String LDC2015E29 = "../LDC_TAC_KBP/LDC2015E29/data/source/mpdfxml/"; 17 | String LDC2015E68 = "../LDC_TAC_KBP/LDC2015E68/data/source/"; 18 | 19 | String KBP2015Train = "../LDC_TAC_KBP/LDC2017E02/data/2015/training/source/"; 20 | String KBP2015Eval = "../LDC_TAC_KBP/LDC2017E02/data/2015/eval/source/"; 21 | 22 | String KBP2016EvalNW = "../LDC_TAC_KBP/LDC2017E02/data/2016/eval/eng/nw/source/"; 23 | String KBP2016EvalDF = "../LDC_TAC_KBP/LDC2017E02/data/2016/eval/eng/df/source/"; 24 | 25 | String KBP2017EvalNW = "../LDC_TAC_KBP/LDC2017E54/data/eng/nw/source/"; 26 | String KBP2017EvalDF = "../LDC_TAC_KBP/LDC2017E54/data/eng/df/source/"; 27 | 28 | SourceReader reader = new SourceReader(); 29 | List KBPSents = new LinkedList<>(); 30 | try { 31 | // LDC2015E29 32 | List LDC2015E29Sents = reader.readSourceFolder(LDC2015E29); 33 | System.out.println("LDC2015E29: " + LDC2015E29Sents.size()); 34 | KBPSents.addAll(LDC2015E29Sents); 35 | // LDC2015E68 36 | List LDC2015E68Sents = reader.readSourceFolder(LDC2015E68); 37 | System.out.println("LDC2015E68: " + LDC2015E68Sents.size()); 38 | KBPSents.addAll(LDC2015E68Sents); 39 | // KBP2015 40 | List LDC2015TrainSents = reader.readSourceFolder(KBP2015Train); 41 | List LDC2015EvalSents = reader.readSourceFolder(KBP2015Eval); 42 | System.out.println("LDC2015: " + (LDC2015TrainSents.size() + LDC2015EvalSents.size())); 43 | KBPSents.addAll(LDC2015TrainSents); 44 | KBPSents.addAll(LDC2015EvalSents); 45 | // KBP 2016 46 | List KBP2016EvalNWSents = 
reader.readSourceFolder(KBP2016EvalNW, false); 47 | List KBP2016EvalDFSents = reader.readSourceFolder(KBP2016EvalDF, true); 48 | System.out.println("KBP2016: " + (KBP2016EvalNWSents.size() + KBP2016EvalDFSents.size())); 49 | KBPSents.addAll(KBP2016EvalNWSents); 50 | KBPSents.addAll(KBP2016EvalDFSents); 51 | // KBP 2017 52 | List KBP2017EvalNWSents = reader.readSourceFolder(KBP2017EvalNW, false); 53 | List KBP2017EvalDFSents = reader.readSourceFolder(KBP2017EvalDF, true); 54 | System.out.println("KBP2017: " + (KBP2017EvalNWSents.size() + KBP2017EvalDFSents.size())); 55 | KBPSents.addAll(KBP2017EvalNWSents); 56 | KBPSents.addAll(KBP2017EvalDFSents); 57 | } catch (Exception e) { 58 | e.printStackTrace(); 59 | } 60 | try { 61 | saveFile("../kbp_sent.txt", KBPSents); 62 | } catch (IOException e) { 63 | e.printStackTrace(); 64 | } 65 | } 66 | 67 | public static void saveFile(String filename, List sents) throws IOException { 68 | BufferedWriter writer = new BufferedWriter(new FileWriter(filename)); 69 | for (Sentence sent : sents) { 70 | String text = sent.text.replace("\t", " "); 71 | if (isContainChinese(text) || text.startsWith("http") || text.startsWith("www.") || filter(text)) { 72 | continue; 73 | } 74 | writer.write(sent.filename + "\t" + sent.start + "\t" + text + "\n"); 75 | } 76 | writer.close(); 77 | } 78 | 79 | public static boolean filter(String str) { 80 | List stopwords = Arrays.asList("P.S.", "PS", "snip", 81 | "&", "<", ">", " ", """, 82 | "#", "*", ".", "/", "year", "day", "month", "Â", "-", "[", "]", 83 | "!", "?", ",", ";", "(", ")", ":", "~", "_", 84 | "cof", "sigh", "shrug", "and", "or", "done", "URL"); 85 | for (String w : stopwords) { 86 | str = str.replace(w, " "); 87 | } 88 | Pattern p = Pattern.compile("[0-9]"); 89 | Matcher matcher = p.matcher(str); 90 | str = matcher.replaceAll(" "); 91 | if (str.trim().isEmpty() || str.trim().length() == 1) return true; 92 | return false; 93 | } 94 | 95 | public static boolean isContainChinese(String str) { 96 | Pattern p = Pattern.compile("[\u4E00-\u9FA5]"); 97 | Matcher m = p.matcher(str); 98 | if (m.find()) { 99 | return true; 100 | } 101 | return false; 102 | } 103 | 104 | } 105 | -------------------------------------------------------------------------------- /data/SplitSentences/src/test/Sentence.java: -------------------------------------------------------------------------------- 1 | package test; 2 | 3 | public class Sentence { 4 | public String filename; 5 | public String text; 6 | public int start; 7 | 8 | public Sentence(String filename, String text, int start) { 9 | this.filename = filename; 10 | this.text = text; 11 | this.start = start; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /data/SplitSentences/src/test/SourceReader.java: -------------------------------------------------------------------------------- 1 | package test; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.File; 5 | import java.io.FileReader; 6 | import java.util.Arrays; 7 | import java.util.LinkedList; 8 | import java.util.List; 9 | import java.util.Properties; 10 | import java.util.regex.Matcher; 11 | import java.util.regex.Pattern; 12 | 13 | import edu.stanford.nlp.pipeline.CoreDocument; 14 | import edu.stanford.nlp.pipeline.CoreSentence; 15 | import edu.stanford.nlp.pipeline.StanfordCoreNLP; 16 | import edu.stanford.nlp.util.StringUtils; 17 | 18 | public class SourceReader { 19 | 20 | private StanfordCoreNLP pipeline; 21 | private List newsStart = Arrays.asList(new String[]{"AFP", 
"APW", "CNA", "NYT", "WPB", "XIN"}); 22 | 23 | public SourceReader() { 24 | Properties props = new Properties(); 25 | props.setProperty("annotators", "tokenize,ssplit"); 26 | pipeline = new StanfordCoreNLP(props); 27 | } 28 | 29 | public List readSourceFolder(String folderPath, boolean isDF) throws Exception { 30 | File folder = new File(folderPath); 31 | List results = new LinkedList<>(); 32 | for (String file : folder.list()) { 33 | results.addAll(readSourceFile(folderPath + file, isDF)); 34 | } 35 | return results; 36 | } 37 | 38 | public List readSourceFolder(String folderPath) throws Exception { 39 | File folder = new File(folderPath); 40 | List results = new LinkedList<>(); 41 | for (String file : folder.list()) { 42 | if (newsStart.contains(file.substring(0, 3))) { // News 43 | results.addAll(readSourceFile(folderPath + file, false)); 44 | } else { // Forum 45 | results.addAll(readSourceFile(folderPath + file, true)); 46 | } 47 | } 48 | return results; 49 | } 50 | 51 | public List readSourceFile(String filePath, boolean isDF) throws Exception { 52 | if (isDF) { // Forum 53 | return forumArticleReader(filePath, this.pipeline); 54 | } else { // News 55 | return newsArticleReader(filePath, this.pipeline); 56 | } 57 | } 58 | 59 | private static List forumArticleReader(String filePath, StanfordCoreNLP model) throws Exception { 60 | BufferedReader br = new BufferedReader(new FileReader(filePath)); 61 | String filename = new File(filePath).getName(); 62 | List filters = Arrays.asList(".txt", ".xml", ".mpdf", ".cmp"); 63 | for (String w : filters) { 64 | filename = filename.replace(w, ""); 65 | } 66 | String line; 67 | int start = 0; 68 | List results = new LinkedList<>(); 69 | while ((line = br.readLine()) != null) { 70 | int length = line.length() + 1; 71 | line = line.trim(); 72 | if (line.startsWith("") || line.startsWith("")) { 73 | start += length; 74 | continue; 75 | } 76 | List sents = splitSentences(filename, line, start, model); 77 | results.addAll(sents); 78 | start += length; 79 | } 80 | br.close(); 81 | return results; 82 | } 83 | 84 | private static List newsArticleReader(String filePath, StanfordCoreNLP model) throws Exception { 85 | BufferedReader br = new BufferedReader(new FileReader(filePath)); 86 | String filename = new File(filePath).getName(); 87 | List filters = Arrays.asList(".txt", ".xml", ".mpdf", ".cmp"); 88 | for (String w : filters) { 89 | filename = filename.replace(w, ""); 90 | } 91 | String line; 92 | String Flag = ""; 93 | String text = ""; 94 | int start = 0; 95 | List results = new LinkedList<>(); 96 | while ((line = br.readLine()) != null) { 97 | int length = line.length() + 1; 98 | if (line.trim().equals("")) { 99 | Flag = "TEXT"; 100 | start += length; 101 | continue; 102 | } else if (line.trim().equals("") || line.trim().equals("

") || line.trim().equals("")) { 103 | Flag = "PARA"; 104 | start += length; 105 | continue; 106 | } else if (line.trim().equals("") || line.trim().equals("

") || line.trim().equals("")) { 107 | Flag = ""; 108 | List sentences = splitSentences(filename, text, start, model); 109 | results.addAll(sentences); 110 | start += text.length() + length; 111 | text = ""; 112 | continue; 113 | } else if (line.trim().equals("
")) { 114 | Flag = ""; 115 | start += length; 116 | text = ""; 117 | continue; 118 | } 119 | if (Flag.equals("PARA")) { 120 | text += line + " "; 121 | continue; 122 | } else if (Flag.equals("TEXT")) { 123 | List sentences = splitSentences(filename, line, start, model); 124 | results.addAll(sentences); 125 | start += length; 126 | continue; 127 | } 128 | start += length; 129 | } 130 | br.close(); 131 | return results; 132 | } 133 | 134 | private static List splitSentences(String filename, String text, int start, StanfordCoreNLP pipeline) throws Exception { 135 | if (text.contains("<")) { // html file 136 | Pattern p_html = Pattern.compile("<[^>]+>", Pattern.CASE_INSENSITIVE); 137 | Matcher m_html = p_html.matcher(text); 138 | StringBuffer sb = new StringBuffer(); 139 | while (m_html.find()) { 140 | m_html.appendReplacement(sb, StringUtils.repeat(" ", m_html.group().length())); 141 | } 142 | m_html.appendTail(sb); 143 | text = sb.toString(); 144 | int count = 0; 145 | if (text.startsWith(" ")) { 146 | for (int i = 0; i < text.length(); i++) { 147 | if (text.charAt(i) != ' ') { break; } 148 | count += 1; 149 | } 150 | } 151 | text = text.trim(); 152 | start += count; 153 | } 154 | // split sentence 155 | CoreDocument doc = new CoreDocument(text); 156 | pipeline.annotate(doc); 157 | List results = new LinkedList<>(); 158 | for (CoreSentence sent : doc.sentences()) { 159 | Integer sentOffset = sent.charOffsets().first; 160 | String sentText = sent.text(); 161 | if (sentText.isEmpty() || sentText.length() < 3) continue; 162 | results.add(new Sentence(filename, sentText, start + sentOffset)); 163 | } 164 | return results; 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /data/SplitSentences/src/word/WordCount.java: -------------------------------------------------------------------------------- 1 | package word; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.BufferedWriter; 5 | import java.io.FileNotFoundException; 6 | import java.io.FileReader; 7 | import java.io.FileWriter; 8 | import java.io.IOException; 9 | import java.util.Arrays; 10 | import java.util.HashMap; 11 | import java.util.List; 12 | import java.util.Map; 13 | import java.util.Properties; 14 | 15 | import com.google.gson.Gson; 16 | 17 | import edu.stanford.nlp.ling.CoreLabel; 18 | import edu.stanford.nlp.pipeline.CoreDocument; 19 | import edu.stanford.nlp.pipeline.CoreEntityMention; 20 | import edu.stanford.nlp.pipeline.StanfordCoreNLP; 21 | 22 | public class WordCount { 23 | 24 | private StanfordCoreNLP pipeline; 25 | List entityType = Arrays.asList("PERSON", "LOCATION", "ORGANIZATION"); 26 | List stopwords = Arrays.asList( 27 | "a", "an", "and", "are", "as", "at", "be", "but", "by", 28 | "for", "if", "in", "into", "is", "it", "been", 29 | "no", "not", "of", "on", "or", "such", 30 | "that", "the", "their", "then", "there", "these", 31 | "they", "this", "to", "was", "will", "with", 32 | "he", "she", "his", "her", "were", "do" 33 | ); 34 | public WordCount() { 35 | Properties props = new Properties(); 36 | props.setProperty("annotators", "tokenize,ssplit,pos,lemma,ner"); 37 | props.setProperty("ner.applyFineGrained", "false"); 38 | props.setProperty("ner.applyNumericClassifiers", "false"); 39 | pipeline = new StanfordCoreNLP(props); 40 | } 41 | 42 | public Map getVerbEntity(String document) { 43 | CoreDocument doc = new CoreDocument(document); 44 | Map wordStatistic = new HashMap(); 45 | this.pipeline.annotate(doc); 46 | for (CoreLabel tok : doc.tokens()) { 47 | 
String word = tok.word().toLowerCase(); 48 | if (this.stopwords.contains(word)) continue; 49 | if (tok.tag().startsWith("VB")) { 50 | wordStatistic.put(word, wordStatistic.getOrDefault(word, 0) + 1); 51 | } 52 | } 53 | for (CoreEntityMention em : doc.entityMentions()) { 54 | String entity = em.text().toLowerCase(); 55 | if (this.stopwords.contains(entity) || !this.entityType.contains(em.entityType())) continue; 56 | wordStatistic.put(entity, wordStatistic.getOrDefault(entity, 0) + 1); 57 | } 58 | return wordStatistic; 59 | } 60 | 61 | public static void main(String[] args) throws IOException { 62 | String kbp_sent_filePath = "../kbp_sent.txt"; 63 | BufferedReader br = new BufferedReader(new FileReader(kbp_sent_filePath)); 64 | String line; 65 | Map kbp_documents = new HashMap(); 66 | while ((line = br.readLine()) != null) { 67 | String[] items = line.trim().split("\t"); 68 | if (kbp_documents.containsKey(items[0])) { 69 | kbp_documents.replace(items[0], kbp_documents.get(items[0]) + " " + items[2]); 70 | } else { 71 | kbp_documents.put(items[0], items[2]); 72 | } 73 | } 74 | br.close(); 75 | WordCount wc = new WordCount(); 76 | Gson gson = new Gson(); 77 | BufferedWriter bw = new BufferedWriter(new FileWriter("../kbp_word_count.txt")); 78 | for (Map.Entry entry : kbp_documents.entrySet()) { 79 | System.out.println(entry.getKey()); 80 | String countStr = gson.toJson(wc.getVerbEntity(entry.getValue())); 81 | bw.write(entry.getKey() + "\t" + countStr + "\n"); 82 | } 83 | bw.close(); 84 | } 85 | 86 | } 87 | -------------------------------------------------------------------------------- /data/convert.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from collections import namedtuple 3 | import xml.etree.ElementTree as ET 4 | import os 5 | import re 6 | from typing import Dict, List, Tuple 7 | import logging 8 | import json 9 | import numpy as np 10 | from itertools import combinations 11 | import argparse 12 | from utils import print_data_statistic, filter_events, check_event_conflict 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | parser.add_argument("--kbp_data_dir", default='LDC_TAC_KBP/', type=str) 17 | parser.add_argument("--sent_data_dir", default='./', type=str) 18 | args = parser.parse_args() 19 | 20 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 21 | datefmt='%Y/%m/%d %H:%M:%S', 22 | level=logging.INFO) 23 | 24 | logger = logging.getLogger("Convert") 25 | 26 | SENT_FILE = 'kbp_sent.txt' 27 | DATA_DIRS = { 28 | '2015': [ 29 | 'LDC2015E29/data/ere/mpdfxml', 30 | 'LDC2015E68/data/ere', 31 | 'LDC2017E02/data/2015/training/event_hopper', 32 | 'LDC2017E02/data/2015/eval/hopper' 33 | ], 34 | '2016': [ 35 | 'LDC2017E02/data/2016/eval/eng/nw/ere', 36 | 'LDC2017E02/data/2016/eval/eng/df/ere' 37 | ], 38 | '2017': [ 39 | 'LDC2017E54/data/eng/nw/ere', 40 | 'LDC2017E54/data/eng/df/ere' 41 | ] 42 | } 43 | 44 | Sentence = namedtuple("Sentence", ["start", "text"]) 45 | Filename = namedtuple("Filename", ["doc_id", "file_path"]) 46 | 47 | def get_KBP_sents(sent_file_path:str) -> Dict[str, List[Sentence]]: 48 | '''get sentences in the KBP dataset 49 | # Returns: 50 | - sentence dictionary: {filename: [Sentence]} 51 | ''' 52 | sent_dic = collections.defaultdict(list) 53 | with open(sent_file_path, 'rt', encoding='utf-8') as sents: 54 | for line in sents: 55 | doc_id, start, text = line.strip().split('\t') 56 | sent_dic[doc_id].append(Sentence(int(start), text)) 57 | for sents in sent_dic.values(): 
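# keep each document's sentences ordered by their start offset in the source text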
58 | sents.sort(key=lambda x:x.start) 59 | return sent_dic 60 | 61 | def get_KBP_filenames(version:str) -> List[Filename]: 62 | '''get KBP filenames 63 | # Args: 64 | - version: 2015 / 2016 / 2017 65 | # Return: 66 | - filename list: [Filename] 67 | ''' 68 | assert version in ['2015', '2016', '2017'] 69 | filename_list = [] 70 | for folder in DATA_DIRS[version]: 71 | filename_list += [ 72 | Filename( 73 | re.sub('\.event_hoppers\.xml|\.rich_ere\.xml', '', filename), 74 | os.path.join(folder, filename) 75 | ) for filename in os.listdir(os.path.join(args.kbp_data_dir, folder)) 76 | ] 77 | return filename_list 78 | 79 | def create_new_document(sent_list:List[Sentence]) -> str: 80 | '''create new source document 81 | ''' 82 | document = '' 83 | end = 0 84 | for sent in sent_list: 85 | assert sent.start >= end 86 | document += ' ' * (sent.start - end) 87 | document += sent.text 88 | end = sent.start + len(sent.text) 89 | for sent in sent_list: # check 90 | assert document[sent.start:sent.start+len(sent.text)] == sent.text 91 | return document 92 | 93 | def find_event_sent(doc_id, event_start, trigger, sent_list) -> Tuple[int, int]: 94 | '''find out which sentence the event come from 95 | ''' 96 | for idx, sent in enumerate(sent_list): 97 | s_start, s_end = sent.start, sent.start + len(sent.text) - 1 98 | if s_start <= event_start <= s_end: 99 | e_s_start = event_start - s_start 100 | assert sent.text[e_s_start:e_s_start+len(trigger)] == trigger 101 | return idx, event_start - s_start 102 | print(doc_id) 103 | print(event_start, trigger, '\n') 104 | for sent in sent_list: 105 | print(sent.start, sent.start + len(sent.text) - 1) 106 | return None 107 | 108 | def update_trigger(text, trigger, offset): 109 | punc_set = set('#$%&+=@.,;!?*\\~\'\n\r\t()[]|/’-:{<>}、"。,?“”') 110 | new_trigger = trigger 111 | if offset + len(trigger) < len(text) and text[offset + len(trigger)] != ' ' and text[offset + len(trigger)] not in punc_set: 112 | for c in text[offset + len(trigger):]: 113 | if c == ' ' or c in punc_set: 114 | break 115 | new_trigger += c 116 | new_trigger = new_trigger.strip('\n\r\t') 117 | new_trigger = new_trigger.strip(u'\x94') 118 | if new_trigger != trigger: 119 | logger.warning(f'update: [{trigger}]({len(trigger)}) - [{new_trigger}]({len(new_trigger)})') 120 | return new_trigger 121 | 122 | def xml_parser(file_path:str, sent_list:List[Sentence]) -> Dict: 123 | '''KBP datafile XML parser 124 | # Args: 125 | - file_path: xml file path 126 | - sent_list: Sentences of file 127 | ''' 128 | tree = ET.ElementTree(file=file_path) 129 | doc_id = re.sub('\.event_hoppers\.xml|\.rich_ere\.xml', '', os.path.split(file_path)[1]) 130 | document = create_new_document(sent_list) 131 | sentence_list = [{'start': sent.start, 'text': sent.text} for sent in sent_list] 132 | event_list = [] 133 | cluster_list = [] 134 | for hopper in tree.iter(tag='hopper'): 135 | h_id = hopper.attrib['id'] # hopper id 136 | h_events = [] 137 | for event in hopper.iter(tag='event_mention'): 138 | att = event.attrib 139 | e_id = att['id'] 140 | e_type, e_subtype, e_realis = att['type'], att['subtype'], att['realis'] 141 | e_trigger = event.find('trigger').text.strip() 142 | e_start = int(event.find('trigger').attrib['offset']) 143 | e_s_index, e_s_start = find_event_sent(doc_id, e_start, e_trigger, sent_list) 144 | e_trigger = update_trigger(sent_list[e_s_index].text, e_trigger, e_s_start) 145 | event_list.append({ 146 | 'event_id': e_id, 147 | 'start': e_start, 148 | 'trigger': e_trigger, 149 | 'type': e_type, 150 | 'subtype': 
e_subtype, 151 | 'realis': e_realis, 152 | 'sent_idx': e_s_index, 153 | 'sent_start': e_s_start 154 | }) 155 | h_events.append(e_id) 156 | cluster_list.append({ 157 | 'hopper_id': h_id, 158 | 'events': h_events 159 | }) 160 | return { 161 | 'doc_id': doc_id, 162 | 'document': document, 163 | 'sentences': sentence_list, 164 | 'events': event_list, 165 | 'clusters': cluster_list 166 | } 167 | 168 | def split_dev(doc_list:list, valid_doc_num:int, valid_event_num:int, valid_chain_num:int): 169 | '''split dev set from full train set 170 | ''' 171 | docs_id = [doc['doc_id'] for doc in doc_list] 172 | docs_event_num = np.asarray([len(doc['events']) for doc in doc_list]) 173 | docs_event_num[docs_id.index('bolt-eng-DF-170-181109-47916')] += 2 174 | docs_event_num[docs_id.index('bolt-eng-DF-170-181109-48534')] += 1 175 | docs_cluster_num = np.asarray([len(doc['clusters']) for doc in doc_list]) 176 | logger.info(f'Train & Dev set: Doc: {len(docs_id)} | Event: {docs_event_num.sum()} | Cluster: {docs_cluster_num.sum()}') 177 | train_docs, dev_docs = [], [] 178 | logger.info(f'finding the correct split...') 179 | for indexs in combinations(range(len(docs_id)), valid_doc_num): 180 | indexs = np.asarray(indexs) 181 | if ( 182 | docs_event_num[indexs].sum() == valid_event_num and 183 | docs_cluster_num[indexs].sum() == valid_chain_num 184 | ): 185 | logger.info(f'Done!') 186 | for idx, doc in enumerate(doc_list): 187 | if idx in indexs: 188 | dev_docs.append(doc) 189 | else: 190 | train_docs.append(doc) 191 | break 192 | return train_docs, dev_docs 193 | 194 | if __name__ == "__main__": 195 | docs = collections.defaultdict(list) 196 | kbp_sent_list = get_KBP_sents(os.path.join(args.sent_data_dir, SENT_FILE)) 197 | for dataset in ['2015', '2016', '2017']: 198 | logger.info(f"parsing xml files in KBP {dataset} ...") 199 | for filename in get_KBP_filenames(dataset): 200 | doc_results = xml_parser(os.path.join(args.kbp_data_dir, filename.file_path), kbp_sent_list[filename.doc_id]) 201 | docs[f'kbp_{dataset}'].append(doc_results) 202 | logger.info(f"Finished!") 203 | print_data_statistic(docs[f'kbp_{dataset}'], dataset) 204 | # split Dev set 205 | train_docs, dev_docs = split_dev(docs['kbp_2015'] + docs['kbp_2016'], 82, 2382, 1502) 206 | kbp_dataset = { 207 | 'train': train_docs, 208 | 'dev': dev_docs, 209 | 'test': docs['kbp_2017'] 210 | } 211 | for doc_list in kbp_dataset.values(): 212 | doc_list.sort(key=lambda x:x['doc_id']) 213 | for dataset in ['train', 'dev', 'test']: 214 | logger.info(f"saving {dataset} set ...") 215 | dataset_doc_list = kbp_dataset[dataset] 216 | print_data_statistic(dataset_doc_list, dataset) 217 | with open(f'{dataset}.json', 'wt', encoding='utf-8') as f: 218 | for doc in dataset_doc_list: 219 | f.write(json.dumps(doc) + '\n') 220 | logger.info(f"Finished!") 221 | # filter events & clusters 222 | for dataset in ['train', 'dev', 'test']: 223 | dataset_doc_list = filter_events(kbp_dataset[dataset], dataset) 224 | check_event_conflict(dataset_doc_list) 225 | print_data_statistic(dataset_doc_list, dataset) 226 | logger.info(f"saving filtered {dataset} set ...") 227 | with open(f'{dataset}_filtered.json', 'wt', encoding='utf-8') as f: 228 | for doc in dataset_doc_list: 229 | f.write(json.dumps(doc) + '\n') 230 | logger.info(f"Finished!") 231 | -------------------------------------------------------------------------------- /data/kbp_vocab.txt: -------------------------------------------------------------------------------- 1 | ["have", "said", "has", "'s", "had", "did", "him", "think", 
"get", "know", "does", "'m", "see", "being", "going", "say", "go", "make", "made", "got", "want", "am", "take", "'re", "us", "apple", "told", "china", "obama", "u.s.", "'ve", "need", "according", "including", "let", "called", "used", "give", "pay", "saying", "went", "keep", "come", "believe", "killed", "convicted", "bush", "found", "left", "read", "find", "done", "came", "seems", "put", "having", "took", "doing", "work", "says", "use", "given", "united states", "feel", "thought", "arrested", "trying", "syria", "getting", "uk", "look", "making", "happened", "buy", "help", "elected", "israel", "posted", "tell", "hope", "agree", "wanted", "like", "become", "makes", "stop", "charged", "seen", "start", "known", "sent", "asked", "reported", "using", "russia", "allowed", "set", "understand", "involved", "mean", "based", "started", "taken", "died", "try", "run", "paid", "mandela", "live", "iraq", "held", "iran", "america", "heard", "saw", "looking", "happen", "working", "support", "taking", "leave", "expected", "released", "seem", "tried", "xinhua", "wants", "washington", "sentenced", "appear", "nokia", "coming", "remember", "announced", "guess", "eu", "love", "worked", "gave", "accused", "lost", "pakistan", "gets", "goes", "continue", "microsoft", "stay", "began", "served", "fired", "met", "became", "ask", "pardoned", "talking", "living", "gone", "egypt", "india", "hit", "move", "brought", "filed", "care", "following", "kill", "hear", "turned", "knew", "decided", "agreed", "needs", "vote", "snowden", "comes", "call", "allow", "led", "new york", "new york times", "sounds", "received", "knows", "leaving", "bring", "running", "killing", "wait", "north korea", "send", "talk", "die", "meet", "needed", "thinking", "cyprus", "means", "win", "congress", "show", "caused", "born", "explain", "clinton", "calling", "added", "wish", "giving", "florida", "london", "bought", "ordered", "europe", "issued", "stand", "spend", "thank", "shot", "google", "hold", "senate", "forced", "provide", "white house", "seeing", "wonder", "change", "related", "married", "claimed", "end", "texas", "takes", "philippines", "spent", "deal", "moved", "wrote", "foxconn", "france", "looks", "cut", "telling", "un", "follow", "morsi", "sell", "turn", "speak", "paying", "passed", "condensed", "won", "scotland", "usa", "nelson mandela", "showed", "committed", "face", "appointed", "injured", "ruled", "barack obama", "britain", "ukraine", "lose", "buying", "hate", "fighting", "granted", "denied", "confirmed", "germany", "happens", "kept", "considered", "sandusky", "chun", "failed", "remain", "helped", "works", "serve", "gives", "'d", "affected", "caught", "afghanistan", "changed", "leading", "receive", "owned", "cause", "created", "supreme court", "starting", "lead", "believed", "thinks", "sold", "japan", "trump", "followed", "mention", "supposed", "expect", "serving", "realize", "planned", "detained", "refused", "return", "remains", "felt", "extradited", "protect", "doubt", "blame", "moving", "built", "appeared", "offered", "reached", "include", "speaking", "waiting", "carry", "declined", "included", "bangladesh", "istanbul", "create", "spain", "california", "beijing", "decide", "supporting", "helping", "considering", "provided", "lived", "avoid", "rejected", "attacked", "arrived", "executed", "described", "named", "voted", "stopped", "claiming", "vietnam", "shows", "appears", "ended", "check", "returned", "discuss", "asking", "australia", "learn", "bet", "knowing", "imagine", "adding", "prove", "deserve", "remained", "consider", 
"putting", "save", "watch", "continued", "fight", "broke", "accept", "expressed", "ran", "join", "selling", "supported", "wanting", "written", "watching", "suggest", "pardon", "holding", "address", "cairo", "claim", "learned", "turkey", "pass", "played", "united nations", "carried", "signed", "seeking", "building", "grow", "libya", "south africa", "prevent", "losing", "assad", "treated", "meant", "scheduled", "driving", "happening", "add", "lying", "sound", "disagree", "raped", "travel", "mexico", "joined", "becoming", "stated", "required", "defend", "canada", "seemed", "haiyan", "afford", "attempted", "samsung", "gotten", "hired", "steve", "italy", "nominated", "register", "cost", "planning", "shown", "chang", "sought", "putin", "missing", "represent", "cover", "army", "published", "declared", "intended", "worry", "admitted", "looked", "growing", "sending", "urged", "ignore", "steve jobs", "moscow", "quoted", "raised", "post", "build", "zimmerman", "muslim brotherhood", "brazil", "reporting", "fined", "retired", "demanding", "gm", "covered", "clicking", "reading", "begin", "forget", "acting", "hoping", "resigned", "sitting", "argue", "compared", "sued", "fuck", "breaking", "drive", "cia", "dropped", "walk", "admit", "middle east", "ensure", "walked", "england", "watched", "beat", "force", "opposed", "south korea", "lives", "pick", "regarding", "justice department"] -------------------------------------------------------------------------------- /data/run_info.log: -------------------------------------------------------------------------------- 1 | 2022/04/09 04:50:42 - INFO - Convert - parsing xml files in KBP 2015 ... 2 | 2022/04/09 04:50:42 - WARNING - Convert - update: [EX-](3) - [EX-SOUTH](8) 3 | 2022/04/09 04:50:42 - WARNING - Convert - update: [manufacture](11) - [manufacturer](12) 4 | 2022/04/09 04:50:43 - WARNING - Convert - update: [explanation](11) - [explanations](12) 5 | 2022/04/09 04:50:43 - WARNING - Convert - update: [explanation](11) - [explanations](12) 6 | 2022/04/09 04:50:43 - WARNING - Convert - update: [explanation](11) - [explanations](12) 7 | 2022/04/09 04:50:43 - INFO - Convert - Finished! 8 | 2022/04/09 04:50:43 - INFO - Utils - KBP 2015 - Doc: 648 | Event: 18736 | Cluster: 11603 | Singleton: 8484 9 | 2022/04/09 04:50:43 - INFO - Convert - parsing xml files in KBP 2016 ... 10 | 2022/04/09 04:50:44 - INFO - Convert - Finished! 11 | 2022/04/09 04:50:44 - INFO - Utils - KBP 2016 - Doc: 169 | Event: 4155 | Cluster: 3191 | Singleton: 2709 12 | 2022/04/09 04:50:44 - INFO - Convert - parsing xml files in KBP 2017 ... 13 | 2022/04/09 04:50:44 - WARNING - Convert - update: [-](1) - [-1996](5) 14 | 2022/04/09 04:50:45 - INFO - Convert - Finished! 15 | 2022/04/09 04:50:45 - INFO - Utils - KBP 2017 - Doc: 167 | Event: 4375 | Cluster: 2963 | Singleton: 2358 16 | 17 | 2022/04/09 04:50:45 - INFO - Convert - Train & Dev set: Doc: 817 | Event: 22894 | Cluster: 14794 18 | 2022/04/09 04:50:45 - INFO - Convert - finding the correct split... 19 | 2022/04/09 05:09:33 - INFO - Convert - Done! 20 | 2022/04/09 05:09:33 - INFO - Convert - saving train set ... 21 | 2022/04/09 05:09:33 - INFO - Utils - KBP train - Doc: 735 | Event: 20509 | Cluster: 13292 | Singleton: 10067 22 | 2022/04/09 05:09:33 - INFO - Convert - Finished! 23 | 2022/04/09 05:09:33 - INFO - Convert - saving dev set ... 24 | 2022/04/09 05:09:33 - INFO - Utils - KBP dev - Doc: 82 | Event: 2382 | Cluster: 1502 | Singleton: 1126 25 | 2022/04/09 05:09:33 - INFO - Convert - Finished! 
26 | 2022/04/09 05:09:33 - INFO - Convert - saving test set ... 27 | 2022/04/09 05:09:33 - INFO - Utils - KBP test - Doc: 167 | Event: 4375 | Cluster: 2963 | Singleton: 2358 28 | 2022/04/09 05:09:33 - INFO - Convert - Finished! 29 | 30 | 2022/04/09 05:09:33 - INFO - Filter - KBP train event filtered: 1629 (same 1621 / overlapping 8) 31 | 2022/04/09 05:09:33 - INFO - Filter - KBP train cluster filtered: 951 32 | 2022/04/09 05:09:33 - INFO - Utils - KBP train - Doc: 735 | Event: 18880 | Cluster: 12341 | Singleton: 9369 33 | 2022/04/09 05:09:33 - INFO - Convert - saving filtered train set ... 34 | 2022/04/09 05:09:33 - INFO - Convert - Finished! 35 | 2022/04/09 05:09:33 - INFO - Filter - KBP dev event filtered: 200 (same 198 / overlapping 2) 36 | 2022/04/09 05:09:33 - INFO - Filter - KBP dev cluster filtered: 100 37 | 2022/04/09 05:09:33 - INFO - Utils - KBP dev - Doc: 82 | Event: 2182 | Cluster: 1402 | Singleton: 1051 38 | 2022/04/09 05:09:33 - INFO - Convert - saving filtered dev set ... 39 | 2022/04/09 05:09:33 - INFO - Convert - Finished! 40 | 2022/04/09 05:09:33 - INFO - Filter - KBP test event filtered: 379 (same 378 / overlapping 1) 41 | 2022/04/09 05:09:33 - INFO - Filter - KBP test cluster filtered: 256 42 | 2022/04/09 05:09:33 - INFO - Utils - KBP test - Doc: 167 | Event: 3996 | Cluster: 2707 | Singleton: 2161 43 | 2022/04/09 05:09:33 - INFO - Convert - saving filtered test set ... 44 | 2022/04/09 05:09:33 - INFO - Convert - Finished! -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 4 | datefmt='%Y/%m/%d %H:%M:%S', 5 | level=logging.INFO) 6 | 7 | logger = logging.getLogger("Utils") 8 | filter_logger = logging.getLogger("Filter") 9 | 10 | def print_data_statistic(doc_list, dataset=''): 11 | doc_num = len(doc_list) 12 | event_num = sum([len(doc['events']) for doc in doc_list]) 13 | cluster_num = sum([len(doc['clusters']) for doc in doc_list]) 14 | singleton_num = sum([1 if len(cluster['events']) == 1 else 0 15 | for doc in doc_list for cluster in doc['clusters']]) 16 | logger.info(f"KBP {dataset} - Doc: {doc_num} | Event: {event_num} | Cluster: {cluster_num} | Singleton: {singleton_num}") 17 | 18 | def check_event_conflict(doc_list): 19 | for doc in doc_list: 20 | event_list = doc['events'] 21 | event_list.sort(key=lambda x:x['start']) 22 | if len(event_list) < 2: 23 | continue 24 | for idx in range(len(event_list)-1): 25 | if ( 26 | ( 27 | event_list[idx]['start'] == event_list[idx+1]['start'] and 28 | event_list[idx]['trigger'] == event_list[idx+1]['trigger'] 29 | ) or 30 | ( 31 | event_list[idx]['start'] + len(event_list[idx]['trigger']) > 32 | event_list[idx+1]['start'] 33 | ) 34 | ): 35 | logger.error('{}: ({})[{}] VS ({})[{}]'.format(doc['doc_id'], 36 | event_list[idx]['start'], event_list[idx]['trigger'], 37 | event_list[idx+1]['start'], event_list[idx+1]['trigger'])) 38 | 39 | def filter_events(doc_list, dataset=''): 40 | same = 0 41 | overlapping = 0 42 | cluster_num_filtered = 0 43 | for doc in doc_list: 44 | event_list = doc['events'] 45 | event_list.sort(key=lambda x:x['start']) 46 | event_filtered = [] 47 | if len(event_list) < 2: 48 | continue 49 | new_event_list, should_add = [], True 50 | for idx in range(len(event_list)-1): 51 | if (event_list[idx]['start'] == event_list[idx+1]['start'] and 52 | event_list[idx]['trigger'] == 
event_list[idx+1]['trigger'] 53 | ): 54 | event_filtered.append(event_list[idx]['event_id']) 55 | same += 1 56 | continue 57 | if (event_list[idx]['start'] + len(event_list[idx]['trigger']) > 58 | event_list[idx+1]['start'] 59 | ): 60 | overlapping += 1 61 | if len(event_list[idx]['trigger']) < len(event_list[idx+1]['trigger']): 62 | new_event_list.append(event_list[idx]) 63 | should_add = False 64 | else: 65 | event_filtered.append(event_list[idx]['event_id']) 66 | continue 67 | if should_add: 68 | new_event_list.append(event_list[idx]) 69 | else: 70 | event_filtered.append(event_list[idx]['event_id']) 71 | should_add = True 72 | if should_add: 73 | new_event_list.append(event_list[-1]) 74 | doc['events'] = new_event_list 75 | new_clusters = [] 76 | for cluster in doc['clusters']: 77 | new_events = [event_id for event_id in cluster['events'] if event_id not in event_filtered] 78 | if len(new_events) == 0: 79 | cluster_num_filtered += 1 80 | continue 81 | new_clusters.append({ 82 | 'hopper_id': cluster['hopper_id'], 83 | 'events': new_events 84 | }) 85 | doc['clusters'] = new_clusters 86 | filter_logger.info(f'KBP {dataset} event filtered: {same + overlapping} (same {same} / overlapping {overlapping})') 87 | filter_logger.info(f'KBP {dataset} cluster filtered: {cluster_num_filtered}') 88 | return doc_list -------------------------------------------------------------------------------- /download_pt_models.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ../PT_MODELS/bert-base-cased/ 2 | wget -P ../PT_MODELS/bert-base-cased/ https://huggingface.co/bert-base-cased/resolve/main/pytorch_model.bin 3 | wget -P ../PT_MODELS/bert-base-cased/ https://huggingface.co/bert-base-cased/resolve/main/README.md 4 | wget -P ../PT_MODELS/bert-base-cased/ https://huggingface.co/bert-base-cased/resolve/main/config.json 5 | wget -P ../PT_MODELS/bert-base-cased/ https://huggingface.co/bert-base-cased/resolve/main/tokenizer.json 6 | wget -P ../PT_MODELS/bert-base-cased/ https://huggingface.co/bert-base-cased/resolve/main/tokenizer_config.json 7 | wget -P ../PT_MODELS/bert-base-cased/ https://huggingface.co/bert-base-cased/resolve/main/vocab.txt 8 | mkdir -p ../PT_MODELS/bert-large-cased/ 9 | wget -P ../PT_MODELS/bert-large-cased/ https://huggingface.co/bert-large-cased/resolve/main/pytorch_model.bin 10 | wget -P ../PT_MODELS/bert-large-cased/ https://huggingface.co/bert-large-cased/resolve/main/README.md 11 | wget -P ../PT_MODELS/bert-large-cased/ https://huggingface.co/bert-large-cased/resolve/main/config.json 12 | wget -P ../PT_MODELS/bert-large-cased/ https://huggingface.co/bert-large-cased/resolve/main/tokenizer.json 13 | wget -P ../PT_MODELS/bert-large-cased/ https://huggingface.co/bert-large-cased/resolve/main/tokenizer_config.json 14 | wget -P ../PT_MODELS/bert-large-cased/ https://huggingface.co/bert-large-cased/resolve/main/vocab.txt 15 | mkdir -p ../PT_MODELS/roberta-base/ 16 | wget -P ../PT_MODELS/roberta-base/ https://huggingface.co/roberta-base/resolve/main/pytorch_model.bin 17 | wget -P ../PT_MODELS/roberta-base/ https://huggingface.co/roberta-base/resolve/main/README.md 18 | wget -P ../PT_MODELS/roberta-base/ https://huggingface.co/roberta-base/resolve/main/config.json 19 | wget -P ../PT_MODELS/roberta-base/ https://huggingface.co/roberta-base/resolve/main/dict.txt 20 | wget -P ../PT_MODELS/roberta-base/ https://huggingface.co/roberta-base/resolve/main/merges.txt 21 | wget -P ../PT_MODELS/roberta-base/ 
https://huggingface.co/roberta-base/resolve/main/tokenizer.json 22 | wget -P ../PT_MODELS/roberta-base/ https://huggingface.co/roberta-base/resolve/main/vocab.json 23 | mkdir -p ../PT_MODELS/roberta-large/ 24 | wget -P ../PT_MODELS/roberta-large/ https://huggingface.co/roberta-large/resolve/main/pytorch_model.bin 25 | wget -P ../PT_MODELS/roberta-large/ https://huggingface.co/roberta-large/resolve/main/README.md 26 | wget -P ../PT_MODELS/roberta-large/ https://huggingface.co/roberta-large/resolve/main/config.json 27 | wget -P ../PT_MODELS/roberta-large/ https://huggingface.co/roberta-large/resolve/main/merges.txt 28 | wget -P ../PT_MODELS/roberta-large/ https://huggingface.co/roberta-large/resolve/main/tokenizer.json 29 | wget -P ../PT_MODELS/roberta-large/ https://huggingface.co/roberta-large/resolve/main/vocab.json 30 | mkdir -p ../PT_MODELS/SpanBERT/spanbert-base-cased 31 | wget -P ../PT_MODELS/SpanBERT/spanbert-base-cased https://huggingface.co/SpanBERT/spanbert-base-cased/resolve/main/pytorch_model.bin 32 | wget -P ../PT_MODELS/SpanBERT/spanbert-base-cased https://huggingface.co/SpanBERT/spanbert-base-cased/resolve/main/config.json 33 | wget -P ../PT_MODELS/SpanBERT/spanbert-base-cased https://huggingface.co/SpanBERT/spanbert-base-cased/resolve/main/vocab.txt 34 | mkdir -p ../PT_MODELS/SpanBERT/spanbert-large-cased 35 | wget -P ../PT_MODELS/SpanBERT/spanbert-large-cased https://huggingface.co/SpanBERT/spanbert-large-cased/resolve/main/pytorch_model.bin 36 | wget -P ../PT_MODELS/SpanBERT/spanbert-large-cased https://huggingface.co/SpanBERT/spanbert-large-cased/resolve/main/config.json 37 | wget -P ../PT_MODELS/SpanBERT/spanbert-large-cased https://huggingface.co/SpanBERT/spanbert-large-cased/resolve/main/vocab.txt 38 | mkdir -p ../PT_MODELS/allenai/longformer-large-4096/ 39 | wget -P ../PT_MODELS/allenai/longformer-large-4096/ https://huggingface.co/allenai/longformer-large-4096/resolve/main/pytorch_model.bin 40 | wget -P ../PT_MODELS/allenai/longformer-large-4096/ https://huggingface.co/allenai/longformer-large-4096/resolve/main/config.json 41 | wget -P ../PT_MODELS/allenai/longformer-large-4096/ https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt 42 | wget -P ../PT_MODELS/allenai/longformer-large-4096/ https://huggingface.co/allenai/longformer-large-4096/resolve/main/tokenizer.json 43 | wget -P ../PT_MODELS/allenai/longformer-large-4096/ https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json 44 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.3 2 | torch==1.11.0 3 | seqeval==1.2.2 4 | scikit-learn==1.1.2 5 | allennlp==2.9.2 6 | transformers==4.17.0 7 | -------------------------------------------------------------------------------- /src/ablation/arg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # Required parameters 7 | parser.add_argument("--output_dir", default=None, type=str, required=True, 8 | help="The output directory where the model checkpoints and predictions will be written.", 9 | ) 10 | parser.add_argument("--train_file", default=None, type=str, required=True, help="The input training file.") 11 | parser.add_argument("--dev_file", default=None, type=str, required=True, help="The input evaluation file.") 12 | parser.add_argument("--test_file", 
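# Note: this ablation copy of arg.py mirrors src/global_event_coref/arg.py, except that
# --model_type, --model_checkpoint and --max_seq_length are relaxed to required=False,
# presumably because without_global_encoder.py runs with no global document encoder
# (its launcher script passes none of these arguments).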
default=None, type=str, required=True, help="The input testing file.") 13 | 14 | parser.add_argument("--model_type", 15 | default="longformer", type=str, required=False 16 | ) 17 | parser.add_argument("--model_checkpoint", 18 | default="allenai/longformer-base-4096", type=str, required=False, 19 | help="Path to pretrained model or model identifier from huggingface.co/models", 20 | ) 21 | parser.add_argument("--max_seq_length", default=4096, type=int, required=False) 22 | parser.add_argument("--matching_style", default="multi", type=str, required=True, 23 | help="how to match two event representations" 24 | ) 25 | 26 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 27 | parser.add_argument("--do_test", action="store_true", help="Whether to run eval on the test set.") 28 | parser.add_argument("--do_predict", action="store_true", help="Whether to save predicted labels.") 29 | parser.add_argument("--do_analysis", action="store_true", help="Whether to do analysis on the test set.") 30 | 31 | # Other parameters 32 | parser.add_argument("--cache_dir", default=None, type=str, 33 | help="Where do you want to store the pre-trained models downloaded from s3" 34 | ) 35 | parser.add_argument("--topic_model", default='stm', type=str, 36 | choices=['stm', 'stm_bn', 'vmf'] 37 | ) 38 | parser.add_argument("--topic_dim", default=32, type=int) 39 | parser.add_argument("--topic_inter_map", default=64, type=int) 40 | parser.add_argument("--mention_encoder_type", default="bert", type=str) 41 | parser.add_argument("--mention_encoder_checkpoint", 42 | default="bert-large-cased", type=str, 43 | help="Path to pretrained model or model identifier from huggingface.co/models", 44 | ) 45 | parser.add_argument("--include_mention_context", action="store_true") 46 | parser.add_argument("--max_mention_length", default=512, type=int) 47 | parser.add_argument("--add_contrastive_loss", action="store_true") 48 | parser.add_argument("--softmax_loss", default='ce', type=str, 49 | help="The loss function for softmax model.", 50 | choices=['lsr', 'focal', 'ce'] 51 | ) 52 | 53 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") 54 | parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.") 55 | parser.add_argument("--batch_size", default=4, type=int) 56 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 57 | 58 | parser.add_argument("--adam_beta1", default=0.9, type=float, 59 | help="Beta1 for Adam optimizer." 60 | ) 61 | parser.add_argument("--adam_beta2", default=0.98, type=float, 62 | help="Beta2 for Adam optimizer." 63 | ) 64 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 65 | help="Epsilon for Adam optimizer." 66 | ) 67 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 68 | help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% of training." 69 | ) 70 | parser.add_argument("--weight_decay", default=0.01, type=float, 71 | help="Weight decay if we apply some." 
72 | ) 73 | args = parser.parse_args() 74 | return args -------------------------------------------------------------------------------- /src/ablation/chunk_global_encoder.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./ChunkBertEncoder_M-multi-cosine_results/ 2 | 3 | python3 chunk_global_encoder.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --model_type=bert \ 6 | --model_checkpoint=../../../PT_MODELS/bert-large-cased/ \ 7 | --mention_encoder_type=bert \ 8 | --mention_encoder_checkpoint=../../../PT_MODELS/bert-base-cased/ \ 9 | --topic_model=vmf \ 10 | --topic_dim=32 \ 11 | --topic_inter_map=64 \ 12 | --train_file=../../data/train_filtered.json \ 13 | --dev_file=../../data/dev_filtered.json \ 14 | --test_file=../../data/test_filtered.json \ 15 | --max_seq_length=512 \ 16 | --max_mention_length=256 \ 17 | --learning_rate=1e-5 \ 18 | --matching_style=multi_cosine \ 19 | --softmax_loss=ce \ 20 | --num_train_epochs=50 \ 21 | --batch_size=1 \ 22 | --do_train \ 23 | --warmup_proportion=0. \ 24 | --seed=42 -------------------------------------------------------------------------------- /src/ablation/without_global_encoder.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./NoGlobal_M-multi-cosine_results/ 2 | 3 | python3 without_global_encoder.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --mention_encoder_type=bert \ 6 | --mention_encoder_checkpoint=../../../PT_MODELS/bert-base-cased/ \ 7 | --topic_model=vmf \ 8 | --topic_dim=32 \ 9 | --topic_inter_map=64 \ 10 | --train_file=../../data/train_filtered.json \ 11 | --dev_file=../../data/dev_filtered.json \ 12 | --test_file=../../data/test_filtered.json \ 13 | --max_mention_length=256 \ 14 | --learning_rate=1e-5 \ 15 | --matching_style=multi_cosine \ 16 | --softmax_loss=ce \ 17 | --num_train_epochs=50 \ 18 | --batch_size=1 \ 19 | --do_train \ 20 | --warmup_proportion=0. 
\ 21 | --seed=42 -------------------------------------------------------------------------------- /src/analysis/run_analysis.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import classification_report 2 | from collections import defaultdict 3 | import sys 4 | sys.path.append('../../') 5 | from src.analysis.utils import get_event_pair_set 6 | 7 | gold_coref_file = '../../data/test.json' 8 | pred_coref_file = 'MaskTopicBN_M-multi-cosine.json' 9 | 10 | def all_metrics(gold_coref_file, pred_coref_file): 11 | gold_coref_results, pred_coref_results = get_event_pair_set(gold_coref_file, pred_coref_file) 12 | all_event_pairs = [] # (gold_coref, pred_coref) 13 | for doc_id, gold_coref_result_dict in gold_coref_results.items(): 14 | # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)} 15 | gold_unrecognized_event_pairs, gold_recognized_event_pairs = ( 16 | gold_coref_result_dict['unrecognized_event_pairs'], 17 | gold_coref_result_dict['recognized_event_pairs'] 18 | ) 19 | pred_coref_result_dict = pred_coref_results[doc_id] 20 | # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)} 21 | pred_recognized_event_pairs, pred_wrong_event_pairs = ( 22 | pred_coref_result_dict['recognized_event_pairs'], 23 | pred_coref_result_dict['wrong_event_pairs'] 24 | ) 25 | for pair_results in gold_unrecognized_event_pairs.values(): 26 | all_event_pairs.append([str(pair_results[0]), '2']) 27 | for pair_id, pair_results in gold_recognized_event_pairs.items(): 28 | all_event_pairs.append([str(pair_results[0]), str(pred_recognized_event_pairs[pair_id][0])]) 29 | for pair_id, pair_results in pred_wrong_event_pairs.items(): 30 | all_event_pairs.append(['0', str(pair_results[0])]) 31 | y_true, y_pred = [res[0] for res in all_event_pairs], [res[1] for res in all_event_pairs] 32 | metrics = {'ALL': classification_report(y_true=y_true, y_pred=y_pred, output_dict=True)['1']} 33 | return metrics 34 | 35 | def different_distance_metrics(gold_coref_file, pred_coref_file, adj_distance=3): 36 | gold_coref_results, pred_coref_results = get_event_pair_set(gold_coref_file, pred_coref_file) 37 | same_event_pairs, adj_event_pairs, far_event_pairs = [], [], [] 38 | for doc_id, gold_coref_result_dict in gold_coref_results.items(): 39 | # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)} 40 | gold_unrecognized_event_pairs, gold_recognized_event_pairs = ( 41 | gold_coref_result_dict['unrecognized_event_pairs'], 42 | gold_coref_result_dict['recognized_event_pairs'] 43 | ) 44 | pred_coref_result_dict = pred_coref_results[doc_id] 45 | # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)} 46 | pred_recognized_event_pairs, pred_wrong_event_pairs = ( 47 | pred_coref_result_dict['recognized_event_pairs'], 48 | pred_coref_result_dict['wrong_event_pairs'] 49 | ) 50 | for pair_results in gold_unrecognized_event_pairs.values(): 51 | sent_dist = pair_results[1] 52 | pair_coref = [str(pair_results[0]), '2'] 53 | if sent_dist == 0: # same sentence 54 | same_event_pairs.append(pair_coref) 55 | elif sent_dist < adj_distance: # adjacent sentence 56 | adj_event_pairs.append(pair_coref) 57 | else: # far sentence 58 | far_event_pairs.append(pair_coref) 59 | for pair_id, pair_results in gold_recognized_event_pairs.items(): 60 | sent_dist = pair_results[1] 61 | pair_coref = [str(pair_results[0]), str(pred_recognized_event_pairs[pair_id][0])] 62 | if sent_dist == 0: # same sentence 63 | same_event_pairs.append(pair_coref) 64 | 
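# Note: sent_dist buckets used throughout different_distance_metrics -- 0 means both
# triggers occur in the same sentence, 1..adj_distance-1 (i.e. 1-2 with the default
# adj_distance=3) counts as adjacent, and anything farther lands in the "far" bucket.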
elif sent_dist < adj_distance: # adjacent sentence 65 | adj_event_pairs.append(pair_coref) 66 | else: # far sentence 67 | far_event_pairs.append(pair_coref) 68 | for pair_id, pair_results in pred_wrong_event_pairs.items(): 69 | sent_dist = pair_results[1] 70 | pair_coref = ['0', str(pair_results[0])] 71 | if sent_dist == 0: # same sentence 72 | same_event_pairs.append(pair_coref) 73 | elif sent_dist < adj_distance: # adjacent sentence 74 | adj_event_pairs.append(pair_coref) 75 | else: # far sentence 76 | far_event_pairs.append(pair_coref) 77 | metrics = {} 78 | y_true, y_pred = [res[0] for res in same_event_pairs], [res[1] for res in same_event_pairs] 79 | metrics['SAME'] = classification_report(y_true=y_true, y_pred=y_pred, output_dict=True)['1'] 80 | y_true, y_pred = [res[0] for res in adj_event_pairs], [res[1] for res in adj_event_pairs] 81 | metrics['ADJ'] = classification_report(y_true=y_true, y_pred=y_pred, output_dict=True)['1'] 82 | y_true, y_pred = [res[0] for res in far_event_pairs], [res[1] for res in far_event_pairs] 83 | metrics['FAR'] = classification_report(y_true=y_true, y_pred=y_pred, output_dict=True)['1'] 84 | return metrics 85 | 86 | def main_link_metrics(gold_coref_file, pred_coref_file, main_link_length=5, mode='ge'): 87 | assert mode in ['g', 'ge', 'e', 'le', 'l'] 88 | gold_coref_results, pred_coref_results = get_event_pair_set(gold_coref_file, pred_coref_file) 89 | main_link_event_pairs, singleton_event_pairs = [], defaultdict(list) 90 | for doc_id, gold_coref_result_dict in gold_coref_results.items(): 91 | # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)} 92 | gold_unrecognized_event_pairs, gold_recognized_event_pairs = ( 93 | gold_coref_result_dict['unrecognized_event_pairs'], 94 | gold_coref_result_dict['recognized_event_pairs'] 95 | ) 96 | pred_coref_result_dict = pred_coref_results[doc_id] 97 | # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)} 98 | pred_recognized_event_pairs, pred_wrong_event_pairs = ( 99 | pred_coref_result_dict['recognized_event_pairs'], 100 | pred_coref_result_dict['wrong_event_pairs'] 101 | ) 102 | 103 | for pair_id, pair_results in gold_recognized_event_pairs.items(): 104 | e_starts = pair_id.split('-') 105 | e_i_link_len, e_j_link_len = pair_results[2], pair_results[3] 106 | pair_coref = [str(pair_results[0]), str(pred_recognized_event_pairs[pair_id][0])] 107 | if e_i_link_len == 1: 108 | singleton_event_pairs[e_starts[0]].append(pair_coref[0] == pair_coref[1]) 109 | if e_j_link_len == 1: 110 | singleton_event_pairs[e_starts[1]].append(pair_coref[0] == pair_coref[1]) 111 | if mode == 'g': 112 | if e_i_link_len > main_link_length or e_j_link_len > main_link_length: 113 | main_link_event_pairs.append(pair_coref) 114 | elif mode == 'ge': 115 | if e_i_link_len >= main_link_length or e_j_link_len >= main_link_length: 116 | main_link_event_pairs.append(pair_coref) 117 | elif mode == 'e': 118 | if e_i_link_len == main_link_length or e_j_link_len == main_link_length: 119 | main_link_event_pairs.append(pair_coref) 120 | elif mode == 'le': 121 | if (e_i_link_len <= main_link_length and e_i_link_len > 1) or (e_j_link_len <= main_link_length and e_j_link_len > 1): 122 | main_link_event_pairs.append(pair_coref) 123 | elif mode == 'l': 124 | if (e_i_link_len < main_link_length and e_i_link_len > 1) or (e_j_link_len < main_link_length and e_j_link_len > 1): 125 | main_link_event_pairs.append(pair_coref) 126 | for pair_id, pair_results in gold_unrecognized_event_pairs.items(): 127 | e_starts = 
pair_id.split('-') 128 | e_i_link_len, e_j_link_len = pair_results[2], pair_results[3] 129 | pair_coref = [str(pair_results[0]), '2'] 130 | if e_i_link_len == 1 and e_starts[0] not in singleton_event_pairs: 131 | singleton_event_pairs[e_starts[0]].append(False) 132 | if e_j_link_len == 1 and e_starts[1] not in singleton_event_pairs: 133 | singleton_event_pairs[e_starts[1]].append(False) 134 | if mode == 'g': 135 | if e_i_link_len > main_link_length or e_j_link_len > main_link_length: 136 | main_link_event_pairs.append(pair_coref) 137 | elif mode == 'ge': 138 | if e_i_link_len >= main_link_length or e_j_link_len >= main_link_length: 139 | main_link_event_pairs.append(pair_coref) 140 | elif mode == 'e': 141 | if e_i_link_len == main_link_length or e_j_link_len == main_link_length: 142 | main_link_event_pairs.append(pair_coref) 143 | elif mode == 'le': 144 | if (e_i_link_len <= main_link_length and e_i_link_len > 1) or (e_j_link_len <= main_link_length and e_j_link_len > 1): 145 | main_link_event_pairs.append(pair_coref) 146 | elif mode == 'l': 147 | if (e_i_link_len < main_link_length and e_i_link_len > 1) or (e_j_link_len < main_link_length and e_j_link_len > 1): 148 | main_link_event_pairs.append(pair_coref) 149 | 150 | for pair_id, pair_results in pred_wrong_event_pairs.items(): 151 | e_starts = pair_id.split('-') 152 | if e_starts[0] in singleton_event_pairs: 153 | singleton_event_pairs[e_starts[0]].append(pair_results[0] == 0) 154 | if e_starts[1] in singleton_event_pairs: 155 | singleton_event_pairs[e_starts[1]].append(pair_results[0] == 0) 156 | 157 | mode_str = {'g': '>', 'ge': '>=', 'e': '==', 'le': '<=', 'l': '<'}[mode] 158 | metrics = {} 159 | y_true, y_pred = [res[0] for res in main_link_event_pairs], [res[1] for res in main_link_event_pairs] 160 | metrics[f'Main Link ({mode_str}{main_link_length})'] = classification_report(y_true=y_true, y_pred=y_pred, output_dict=True)['1'] 161 | wrong_num = sum([False in singleton_coref_correct for singleton_coref_correct in singleton_event_pairs.values()]) 162 | print(wrong_num) 163 | print(len(singleton_event_pairs)) 164 | metrics['Singleton Acc'] = (len(singleton_event_pairs) - wrong_num) / len(singleton_event_pairs) * 100 165 | return metrics 166 | 167 | 168 | # print(all_metrics(gold_coref_file, pred_coref_file)) 169 | # print(different_distance_metrics(gold_coref_file, pred_coref_file)) 170 | print(main_link_metrics(gold_coref_file, pred_coref_file, main_link_length=10, mode='ge')) 171 | -------------------------------------------------------------------------------- /src/analysis/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import namedtuple, defaultdict 4 | 5 | Sentence = namedtuple("Sentence", ["start", "text"]) 6 | kbp_sent_dic = defaultdict(list) # {filename: [Sentence]} 7 | with open(os.path.join('../../data/kbp_sent.txt'), 'rt', encoding='utf-8') as sents: 8 | for line in sents: 9 | doc_id, start, text = line.strip().split('\t') 10 | kbp_sent_dic[doc_id].append(Sentence(int(start), text)) 11 | 12 | def get_event_sent_idx(e_start, e_end, sents): 13 | for sent_idx, sent in enumerate(sents): 14 | sent_end = sent.start + len(sent.text) - 1 15 | if e_start >= sent.start and e_end <= sent_end: 16 | return sent_idx 17 | return None 18 | 19 | def get_gold_corefs(gold_test_file): 20 | 21 | def _get_event_cluster_id_and_link_len(event_id, clusters): 22 | for cluster in clusters: 23 | if event_id in cluster['events']: 24 | return 
cluster['hopper_id'], len(cluster['events']) 25 | return None, None 26 | 27 | gold_dict = {} 28 | with open(gold_test_file, 'rt', encoding='utf-8') as f: 29 | for line in f: 30 | sample = json.loads(line.strip()) 31 | clusters = sample['clusters'] 32 | events = sample['events'] 33 | event_pairs = {} # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)} 34 | for i in range(len(events) - 1): 35 | e_i_start = events[i]['start'] 36 | e_i_cluster_id, e_i_link_len = _get_event_cluster_id_and_link_len(events[i]['event_id'], clusters) 37 | assert e_i_cluster_id is not None 38 | e_i_sent_idx = events[i]['sent_idx'] 39 | for j in range(i + 1, len(events)): 40 | e_j_start = events[j]['start'] 41 | e_j_cluster_id, e_j_link_len = _get_event_cluster_id_and_link_len(events[j]['event_id'], clusters) 42 | assert e_j_cluster_id is not None 43 | e_j_sent_idx = events[j]['sent_idx'] 44 | event_pairs[f'{e_i_start}-{e_j_start}'] = [ 45 | 1 if e_i_cluster_id == e_j_cluster_id else 0, abs(int(e_i_sent_idx) - int(e_j_sent_idx)), e_i_link_len, e_j_link_len 46 | ] 47 | gold_dict[sample['doc_id']] = event_pairs 48 | return gold_dict 49 | 50 | def get_pred_coref_results(pred_file_path): 51 | pred_dict = {} 52 | with open(pred_file_path, 'rt', encoding='utf-8') as f: 53 | for line in f: 54 | sample = json.loads(line.strip()) 55 | sents = kbp_sent_dic[sample['doc_id']] 56 | events = sample['events'] 57 | pred_labels = sample['pred_label'] 58 | event_pairs = {} # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)} 59 | event_pair_idx = -1 60 | for i in range(len(events) - 1): 61 | e_i_start = events[i]['start'] 62 | e_i_sent_idx = get_event_sent_idx(events[i]['start'], events[i]['end'], sents) 63 | assert e_i_sent_idx is not None 64 | for j in range(i + 1, len(events)): 65 | event_pair_idx += 1 66 | e_j_start = events[j]['start'] 67 | e_j_sent_idx = get_event_sent_idx(events[j]['start'], events[j]['end'], sents) 68 | assert e_j_sent_idx is not None 69 | event_pairs[f'{e_i_start}-{e_j_start}'] = [pred_labels[event_pair_idx], abs(int(e_i_sent_idx) - int(e_j_sent_idx)), 0, 0] 70 | pred_dict[sample['doc_id']] = event_pairs 71 | return pred_dict 72 | 73 | def get_event_pair_set(gold_coref_file, pred_coref_file): 74 | 75 | gold_coref_results = get_gold_corefs(gold_coref_file) 76 | pred_coref_results = get_pred_coref_results(pred_coref_file) 77 | 78 | new_gold_coref_results = {} 79 | for doc_id, event_pairs in gold_coref_results.items(): 80 | pred_event_pairs = pred_coref_results[doc_id] 81 | unrecognized_event_pairs = {} 82 | recognized_event_pairs = {} 83 | for pair_id, results in event_pairs.items(): 84 | if pair_id in pred_event_pairs: 85 | recognized_event_pairs[pair_id] = results 86 | else: 87 | unrecognized_event_pairs[pair_id] = results 88 | new_gold_coref_results[doc_id] = { 89 | 'unrecognized_event_pairs': unrecognized_event_pairs, 90 | 'recognized_event_pairs': recognized_event_pairs 91 | } 92 | new_pred_coref_results = {} 93 | for doc_id, event_pairs in pred_coref_results.items(): 94 | gold_event_pairs = gold_coref_results[doc_id] 95 | recognized_event_pairs = {} 96 | wrong_event_pairs = {} 97 | for pair_id, results in event_pairs.items(): 98 | if pair_id in gold_event_pairs: 99 | recognized_event_pairs[pair_id] = results 100 | else: 101 | wrong_event_pairs[pair_id] = results 102 | new_pred_coref_results[doc_id] = { 103 | 'recognized_event_pairs': recognized_event_pairs, 104 | 'wrong_event_pairs': wrong_event_pairs 105 | } 106 | 107 | return new_gold_coref_results, 
new_pred_coref_results 108 | -------------------------------------------------------------------------------- /src/clustering/arg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # Required parameters 7 | parser.add_argument("--output_dir", default=None, type=str, required=True, 8 | help="The output directory where the conll files and evaluate results will be written.", 9 | ) 10 | parser.add_argument("--test_golden_filepath", default=None, type=str, required=True, 11 | help="golden test set file path.", 12 | ) 13 | parser.add_argument("--test_pred_filepath", default=None, type=str, required=True, 14 | help="predicted coref file path.", 15 | ) 16 | parser.add_argument("--golden_conll_filename", default=None, type=str, required=True) 17 | parser.add_argument("--pred_conll_filename", default=None, type=str, required=True) 18 | 19 | # Other parameters 20 | parser.add_argument("--do_rescore", action="store_true", help="Whether to rescore coref values.") 21 | parser.add_argument("--rescore_reward", default=0.8, type=float, required=False) 22 | parser.add_argument("--rescore_penalty", default=0.8, type=float, required=False) 23 | parser.add_argument("--do_evaluate", action="store_true", help="Whether to evaluate conll files.") 24 | 25 | args = parser.parse_args() 26 | return args 27 | -------------------------------------------------------------------------------- /src/clustering/cluster.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, defaultdict 2 | 3 | def clustering_greedy(events, pred_labels:list): 4 | ''' 5 | As long as any two event chains share at least one 6 | coreferent event pair, merge them. 
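A toy example: with events {A, B, C, D} and predicted coreferent
pairs A-B and B-C, the initial singleton chains {A} {B} {C} {D}
are merged into {A, B, C} and {D}.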
7 | ''' 8 | def need_merge(set_1, set_2, coref_event_pair_set): 9 | for e1 in set_1: 10 | for e2 in set_2: 11 | if f'{e1}-{e2}' in coref_event_pair_set: 12 | return True 13 | return False 14 | 15 | def find_merge_position(cluster_list, coref_event_pairs): 16 | for i in range(len(cluster_list) - 1): 17 | for j in range(i + 1, len(cluster_list)): 18 | if need_merge(cluster_list[i], cluster_list[j], coref_event_pairs): 19 | return i, j 20 | return -1, -1 21 | 22 | if len(events) > 1: 23 | assert len(pred_labels) == len(events) * (len(events) - 1) / 2 24 | event_pairs = [ 25 | str(events[i]['start']) + '-' + str(events[j]['start']) 26 | for i in range(len(events) - 1) for j in range(i + 1, len(events)) 27 | ] 28 | coref_event_pairs = [event_pair for event_pair, pred in zip(event_pairs, pred_labels) if pred == 1] 29 | cluster_list = [] 30 | for event in events: # init each event as its own chain 31 | cluster_list.append(set([event['start']])) 32 | while True: 33 | i, j = find_merge_position(cluster_list, coref_event_pairs) 34 | if i == -1: # no cluster can be merged 35 | break 36 | cluster_list[i] |= cluster_list[j] 37 | del cluster_list[j] 38 | return cluster_list 39 | 40 | def clustering_rescore(events, pred_labels:list, reward=0.8, penalty=0.8): 41 | event_pairs = [ 42 | str(events[i]['start']) + '-' + str(events[j]['start']) 43 | for i in range(len(events) - 1) for j in range(i + 1, len(events)) 44 | ] 45 | coref_event_pairs = [event_pair for event_pair, pred in zip(event_pairs, pred_labels) if pred == 1] 46 | coref = OrderedDict([(event_pair, 1 if pred == 1 else -1) for event_pair, pred in zip(event_pairs, pred_labels)]) 47 | for i in range(len(events) - 1): 48 | for j in range(i + 1, len(events)): 49 | for k in range(len(events)): 50 | if k == i or k == j: 51 | continue 52 | event_i, event_j, event_k = events[i]['start'], events[j]['start'], events[k]['start'] 53 | coref_i_k = (f'{event_k}-{event_i}' if k < i else f'{event_i}-{event_k}') in coref_event_pairs 54 | coref_j_k = (f'{event_k}-{event_j}' if k < j else f'{event_j}-{event_k}') in coref_event_pairs 55 | if coref_i_k and coref_j_k: 56 | coref[f'{event_i}-{event_j}'] += reward 57 | elif coref_i_k != coref_j_k: 58 | coref[f'{event_i}-{event_j}'] -= penalty 59 | coref = OrderedDict([(event_pair, score) for event_pair, score in coref.items() if score > 0]) 60 | sorted_coref = sorted(coref.items(), key=lambda x:x[1], reverse=True) 61 | cluster_id = 0 62 | events_cluster_ids = {str(event['start']):-1 for event in events} # {event:cluster_id} 63 | for event_pair, _ in sorted_coref: 64 | e_i, e_j = event_pair.split('-') 65 | if events_cluster_ids[e_i] == events_cluster_ids[e_j] == -1: 66 | events_cluster_ids[e_i] = events_cluster_ids[e_j] = cluster_id 67 | cluster_id += 1 68 | elif events_cluster_ids[e_i] == -1: 69 | events_cluster_ids[e_i] = events_cluster_ids[e_j] 70 | elif events_cluster_ids[e_j] == -1: 71 | events_cluster_ids[e_j] = events_cluster_ids[e_i] 72 | for event, c_id in events_cluster_ids.items(): 73 | if c_id == -1: 74 | events_cluster_ids[event] = cluster_id 75 | cluster_id += 1 76 | cluster_list = defaultdict(set) 77 | for event, c_id in events_cluster_ids.items(): 78 | cluster_list[c_id].add(event) 79 | return [v for v in cluster_list.values()] 80 | 81 | def clustering(events, pred_labels:list, mode='rescore', rescore_reward=0.8, rescore_penalty=0.8): 82 | assert mode in ['greedy', 'rescore'] 83 | if mode == 'rescore': 84 | return clustering_rescore(events, pred_labels, rescore_reward, rescore_penalty) 85 | elif mode == 
'greedy': 86 | return clustering_greedy(events, pred_labels) 87 | -------------------------------------------------------------------------------- /src/clustering/run_cluster.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import logging 4 | from tqdm.auto import tqdm 5 | import json 6 | import subprocess 7 | import re 8 | import sys 9 | sys.path.append('../../') 10 | from src.clustering.arg import parse_args 11 | from src.clustering.utils import create_golden_conll_file, get_pred_coref_results, create_pred_conll_file 12 | from src.clustering.cluster import clustering 13 | 14 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 15 | datefmt='%Y/%m/%d %H:%M:%S', 16 | level=logging.INFO) 17 | logger = logging.getLogger("Cluster") 18 | 19 | COREF_RESULTS_REGEX = re.compile(r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*", re.DOTALL) 20 | BLANC_RESULTS_REGEX = re.compile(r".*BLANC: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*", re.DOTALL) 21 | 22 | def official_conll_eval(gold_path, predicted_path, metric, official_stdout=True): 23 | assert metric in ["muc", "bcub", "ceafe", "blanc"] 24 | cmd = ["../../reference-coreference-scorers/scorer.pl", metric, gold_path, predicted_path, "none"] 25 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE) 26 | stdout, stderr = process.communicate() 27 | process.wait() 28 | 29 | stdout = stdout.decode("utf-8") 30 | if stderr is not None: 31 | logger.error(stderr) 32 | 33 | if official_stdout: 34 | logger.info("Official result for {}".format(metric)) 35 | logger.info(stdout) 36 | 37 | coref_results_match = re.match( 38 | BLANC_RESULTS_REGEX if metric == 'blanc' else COREF_RESULTS_REGEX, 39 | stdout 40 | ) 41 | recall = float(coref_results_match.group(1)) 42 | precision = float(coref_results_match.group(2)) 43 | f1 = float(coref_results_match.group(3)) 44 | return {"r": recall, "p": precision, "f": f1} 45 | 46 | if __name__ == '__main__': 47 | args = parse_args() 48 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 49 | raise ValueError( 50 | f'Output directory ({args.output_dir}) already exists and is not empty.') 51 | if not os.path.exists(args.output_dir): 52 | os.mkdir(args.output_dir) 53 | golden_conll_path = os.path.join(args.output_dir, args.golden_conll_filename) 54 | pred_conll_path = os.path.join(args.output_dir, args.pred_conll_filename) 55 | 56 | logger.info(f'creating golden conll file in {args.output_dir} ...') 57 | create_golden_conll_file(args.test_golden_filepath, golden_conll_path) 58 | # clustering 59 | # {doc_id: {'events': event_list, 'pred_labels': pred_coref_labels}} 60 | pred_coref_results = get_pred_coref_results(args.test_pred_filepath) 61 | cluster_dict = {} # {doc_id: [cluster_set_1, cluster_set_2, ...]} 62 | logger.info('clustering ...') 63 | for doc_id, pred_result in tqdm(pred_coref_results.items()): 64 | cluster_list = clustering( 65 | pred_result['events'], 66 | pred_result['pred_labels'], 67 | mode='rescore' if args.do_rescore else 'greedy', 68 | rescore_reward=args.rescore_reward, 69 | rescore_penalty=args.rescore_penalty 70 | ) 71 | cluster_dict[doc_id] = cluster_list 72 | logger.info(f'saving predicted clusters in {args.output_dir} ...') 73 | create_pred_conll_file(cluster_dict, golden_conll_path, pred_conll_path) 74 | with open(os.path.join(args.output_dir, 
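# Note: the run configuration is dumped to args.txt below for reproducibility; with
# --do_evaluate, each of the four CoNLL metrics (muc, bcub, ceafe, blanc) is then scored
# by a separate scorer.pl run via official_conll_eval above, whose stdout is parsed with
# COREF_RESULTS_REGEX / BLANC_RESULTS_REGEX, and the four F1 scores are averaged into avg_f1.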
'args.txt'), 'wt') as f: 75 | f.write(str(args)) 76 | # evaluate on the conll files 77 | if args.do_evaluate: 78 | results = { 79 | m: official_conll_eval(golden_conll_path, pred_conll_path, m, official_stdout=True) 80 | for m in ("muc", "bcub", "ceafe", "blanc") 81 | } 82 | results['avg_f1'] = sum([scores['f'] for scores in results.values()]) / len(results) 83 | logger.info(results) 84 | with open(os.path.join(args.output_dir, 'evaluate_results.json'), 'wt', encoding='utf-8') as f: 85 | f.write(json.dumps(results) + '\n') 86 | shutil.rmtree(args.output_dir) 87 | -------------------------------------------------------------------------------- /src/clustering/run_cluster.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./TEMP/ 2 | 3 | python3 run_cluster.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --test_golden_filepath=../../data/test.json \ 6 | --test_pred_filepath=../../data/XXX_weights.bin_test_pred_corefs.json \ 7 | --golden_conll_filename=gold_test.conll \ 8 | --pred_conll_filename=pred_test.conll \ 9 | --do_evaluate \ 10 | # --do_rescore \ 11 | # --rescore_reward=0.5 \ 12 | # --rescore_penalty=0.5 -------------------------------------------------------------------------------- /src/clustering/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | 4 | def get_pred_coref_results(pred_file_path, ): 5 | pred_results = {} # {doc_id: {'events': event_list, 'pred_labels': pred_coref_labels}} 6 | with open(pred_file_path, 'rt', encoding='utf-8') as f: 7 | for line in f: 8 | sample = json.loads(line.strip()) 9 | pred_results[sample['doc_id']] = { 10 | 'events': sample['events'], 11 | 'pred_labels': sample['pred_label'] 12 | } 13 | return pred_results 14 | 15 | def create_golden_conll_file(test_file_path, conll_file_path): 16 | 17 | def get_event_cluster_idx(event_id:str, clusters): 18 | for idx, cluster in enumerate(clusters): 19 | if event_id in cluster['events']: 20 | return idx 21 | print('ERROR!') 22 | return None 23 | 24 | with open(test_file_path, 'rt', encoding='utf-8') as f_in, \ 25 | open(conll_file_path, 'wt', encoding='utf-8') as f_out: 26 | for line in f_in: 27 | sample = json.loads(line.strip()) 28 | doc_id = sample['doc_id'] 29 | f_out.write(f'#begin document ({doc_id});\n') 30 | clusters = sample['clusters'] 31 | for event in sample['events']: 32 | cluster_idx = get_event_cluster_idx(event['event_id'], clusters) 33 | start = event['start'] 34 | f_out.write(f'{doc_id}\t{start}\txxx\t({cluster_idx})\n') 35 | f_out.write('#end document\n') 36 | 37 | def create_pred_conll_file(cluster_dict:dict, golden_conll_filepath:str, conll_filepath:str, no_repeat=True): 38 | ''' 39 | # Args: 40 | - cluster_dict: {doc_id: [cluster_set_1, cluster_set_2, ...]} 41 | ''' 42 | new_cluster_dict = {} # {doc_id: {event: cluster_idx}} 43 | for doc_id, cluster_list in cluster_dict.items(): 44 | event_cluster_idx = {} # {event: cluster_idx} 45 | for c_idx, cluster in enumerate(cluster_list): 46 | for event in cluster: 47 | event_cluster_idx[str(event)] = c_idx 48 | new_cluster_dict[doc_id] = event_cluster_idx 49 | golden_file_dic = collections.OrderedDict() # {doc_id: [event_1, event_2, ...]} 50 | with open(golden_conll_filepath, 'rt', encoding='utf-8') as f_in: 51 | for line in f_in: 52 | if line.startswith('#begin'): 53 | doc_id = line.replace('#begin document (', '').replace(');', '').strip() 54 | golden_file_dic[doc_id] = [] 55 | elif line.startswith('#end document'): 56 
| continue 57 | else: 58 | _, event, _, _ = line.strip().split('\t') 59 | golden_file_dic[doc_id].append(event) 60 | with open(conll_filepath, 'wt', encoding='utf-8') as f_out: 61 | for doc_id, event_list in golden_file_dic.items(): 62 | event_cluster_idx = new_cluster_dict[doc_id] 63 | f_out.write('#begin document (' + doc_id + ');\n') 64 | if no_repeat: 65 | finish_events = set() 66 | for event in event_list: 67 | if event in event_cluster_idx and event not in finish_events: 68 | cluster_idx = event_cluster_idx[event] 69 | f_out.write(f'{doc_id}\t{event}\txxx\t({cluster_idx})\n') 70 | else: 71 | f_out.write(f'{doc_id}\tnull\tnull\tnull\n') 72 | finish_events.add(event) 73 | else: 74 | for event in event_list: 75 | if event in event_cluster_idx: 76 | cluster_idx = event_cluster_idx[event] 77 | f_out.write(f'{doc_id}\t{event}\txxx\t({cluster_idx})\n') 78 | else: 79 | f_out.write(f'{doc_id}\tnull\tnull\tnull\n') 80 | for event, cluster_idx in event_cluster_idx.items(): 81 | if event in event_list: 82 | continue 83 | f_out.write(f'{doc_id}\t{event}\txxx\t({cluster_idx})\n') 84 | f_out.write('#end document\n') 85 | -------------------------------------------------------------------------------- /src/global_event_coref/analysis.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | WRONG_TYPE = { 4 | 0: 'recognize_non-coref_as_coref', 5 | 1: 'recognize_coref_as_non-coref' 6 | } 7 | 8 | def get_pretty_event(sentences, sent_idx, sent_start, trigger, context=1): 9 | before = ' '.join([sent['text'] for sent in sentences[max(0, sent_idx-context):sent_idx]]).strip() 10 | after = ' '.join([sent['text'] for sent in sentences[sent_idx+1:min(len(sentences), sent_idx+context+1)]]).strip() 11 | event_mention = sentences[sent_idx]['text'] 12 | sent = event_mention[:sent_start] + '#####' + trigger + '#####' + event_mention[sent_start + len(trigger):] 13 | return before + ' ' + sent + ' ' + after 14 | 15 | def find_event_by_start(events, offset): 16 | for event in events: 17 | if event['start'] == offset: 18 | return event 19 | return None 20 | 21 | def get_coref_answer(clusters, e1_id, e2_id): 22 | for cluster in clusters: 23 | events = cluster['events'] 24 | if e1_id in events and e2_id in events: 25 | return 1 26 | elif e1_id in events or e2_id in events: 27 | return 0 28 | return 0 29 | 30 | def get_wrong_samples(doc_id, new_events, predictions, source_events, clusters, sentences, pred_event_filepath): 31 | wrong_1_list, wrong_2_list = [], [] 32 | 33 | pred_event_dict = {} 34 | with open(pred_event_filepath, 'rt' , encoding='utf-8') as f_in: 35 | for line in f_in.readlines(): 36 | sample = json.loads(line.strip()) 37 | pred_event_dict[sample['doc_id']] = [event['start'] for event in sample['pred_label']] 38 | 39 | idx = 0 40 | true_labels = [] 41 | for i in range(len(new_events) - 1): 42 | for j in range(i + 1, len(new_events)): 43 | e1_start, e2_start = new_events[i][0], new_events[j][0] 44 | if e1_start not in pred_event_dict[doc_id] or e2_start not in pred_event_dict[doc_id]: 45 | idx += 1 46 | continue 47 | e1 = find_event_by_start(source_events, e1_start) 48 | e2 = find_event_by_start(source_events, e2_start) 49 | pred_coref = predictions[idx] 50 | idx += 1 51 | true_coref = get_coref_answer(clusters, e1['event_id'], e2['event_id']) 52 | true_labels.append(true_coref) 53 | if pred_coref == true_coref: 54 | continue 55 | pretty_e1 = get_pretty_event(sentences, e1['sent_idx'], e1['sent_start'], e1['trigger']) 56 | pretty_e2 = 
get_pretty_event(sentences, e2['sent_idx'], e2['sent_start'], e2['trigger']) 57 | if pred_coref == 1: 58 | wrong_1_list.append({ 59 | 'doc_id': doc_id, 60 | 'e1_start': e1_start, 61 | 'e2_start': e2_start, 62 | 'e1_info': pretty_e1, 63 | 'e2_info': pretty_e2, 64 | 'wrong_type': 0 65 | }) 66 | else: 67 | wrong_2_list.append({ 68 | 'doc_id': doc_id, 69 | 'e1_start': e1_start, 70 | 'e2_start': e2_start, 71 | 'e1_info': pretty_e1, 72 | 'e2_info': pretty_e2, 73 | 'wrong_type': 1 74 | }) 75 | return wrong_1_list, wrong_2_list, true_labels -------------------------------------------------------------------------------- /src/global_event_coref/arg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # Required parameters 7 | parser.add_argument("--output_dir", default=None, type=str, required=True, 8 | help="The output directory where the model checkpoints and predictions will be written.", 9 | ) 10 | parser.add_argument("--train_file", default=None, type=str, required=True, help="The input training file.") 11 | parser.add_argument("--dev_file", default=None, type=str, required=True, help="The input evaluation file.") 12 | parser.add_argument("--test_file", default=None, type=str, required=True, help="The input testing file.") 13 | 14 | parser.add_argument("--model_type", 15 | default="longformer", type=str, required=True 16 | ) 17 | parser.add_argument("--model_checkpoint", 18 | default="allenai/longformer-base-4096", type=str, required=True, 19 | help="Path to pretrained model or model identifier from huggingface.co/models", 20 | ) 21 | parser.add_argument("--max_seq_length", default=4096, type=int, required=True) 22 | parser.add_argument("--matching_style", default="multi", type=str, required=True, 23 | help="how to match two event representations" 24 | ) 25 | 26 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 27 | parser.add_argument("--do_test", action="store_true", help="Whether to run eval on the test set.") 28 | parser.add_argument("--do_predict", action="store_true", help="Whether to save predicted labels.") 29 | parser.add_argument("--do_analysis", action="store_true", help="Whether to do analysis on the test set.") 30 | 31 | # Other parameters 32 | parser.add_argument("--cache_dir", default=None, type=str, 33 | help="Where do you want to store the pre-trained models downloaded from s3" 34 | ) 35 | parser.add_argument("--topic_model", default='stm', type=str, 36 | choices=['stm', 'stm_bn', 'vmf'] 37 | ) 38 | parser.add_argument("--topic_dim", default=32, type=int) 39 | parser.add_argument("--topic_inter_map", default=64, type=int) 40 | parser.add_argument("--mention_encoder_type", default="bert", type=str) 41 | parser.add_argument("--mention_encoder_checkpoint", 42 | default="bert-large-cased", type=str, 43 | help="Path to pretrained model or model identifier from huggingface.co/models", 44 | ) 45 | parser.add_argument("--include_mention_context", action="store_true") 46 | parser.add_argument("--max_mention_length", default=512, type=int) 47 | parser.add_argument("--add_contrastive_loss", action="store_true") 48 | parser.add_argument("--softmax_loss", default='ce', type=str, 49 | help="The loss function for softmax model.", 50 | choices=['lsr', 'focal', 'ce'] 51 | ) 52 | 53 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") 54 | 
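# Note: the optimization defaults below are placeholders -- the accompanying
# run_global_base*.sh launchers override the ones that matter in practice
# (e.g. --num_train_epochs, --batch_size=1, --warmup_proportion=0.).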
parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.") 55 | parser.add_argument("--batch_size", default=4, type=int) 56 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 57 | 58 | parser.add_argument("--adam_beta1", default=0.9, type=float, 59 | help="Epsilon for Adam optimizer." 60 | ) 61 | parser.add_argument("--adam_beta2", default=0.98, type=float, 62 | help="Epsilon for Adam optimizer." 63 | ) 64 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 65 | help="Epsilon for Adam optimizer." 66 | ) 67 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 68 | help="Proportion of training to perform linear learning rate warmup for,E.g., 0.1 = 10% of training." 69 | ) 70 | parser.add_argument("--weight_decay", default=0.01, type=float, 71 | help="Weight decay if we apply some." 72 | ) 73 | args = parser.parse_args() 74 | return args -------------------------------------------------------------------------------- /src/global_event_coref/run_global_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import json 5 | from tqdm.auto import tqdm 6 | from transformers import AutoConfig, AutoTokenizer 7 | from transformers import AdamW, get_scheduler 8 | import numpy as np 9 | from sklearn.metrics import classification_report 10 | import sys 11 | sys.path.append('../../') 12 | from src.tools import seed_everything, NpEncoder 13 | from src.global_event_coref.arg import parse_args 14 | from src.global_event_coref.data import KBPCoref, get_dataLoader 15 | from src.global_event_coref.modeling import LongformerSoftmaxForEC 16 | from src.global_event_coref.analysis import get_wrong_samples, WRONG_TYPE 17 | 18 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 19 | datefmt='%Y/%m/%d %H:%M:%S', 20 | level=logging.INFO) 21 | logger = logging.getLogger("Model") 22 | 23 | def to_device(args, batch_data): 24 | new_batch_data = {} 25 | for k, v in batch_data.items(): 26 | if k in ['batch_events', 'batch_event_cluster_ids']: 27 | new_batch_data[k] = v 28 | elif k == 'batch_inputs': 29 | new_batch_data[k] = { 30 | k_: v_.to(args.device) for k_, v_ in v.items() 31 | } 32 | else: 33 | raise ValueError(f'Unknown batch data key: {k}') 34 | return new_batch_data 35 | 36 | def train_loop(args, dataloader, model, optimizer, lr_scheduler, epoch, total_loss): 37 | progress_bar = tqdm(range(len(dataloader))) 38 | progress_bar.set_description(f'loss: {0:>7f}') 39 | finish_step_num = epoch * len(dataloader) 40 | 41 | model.train() 42 | for step, batch_data in enumerate(dataloader, start=1): 43 | batch_data = to_device(args, batch_data) 44 | outputs = model(**batch_data) 45 | loss = outputs[0] 46 | 47 | if loss: 48 | optimizer.zero_grad() 49 | loss.backward() 50 | optimizer.step() 51 | lr_scheduler.step() 52 | 53 | total_loss += loss.item() if loss else 0. 
54 | progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}') 55 | progress_bar.update(1) 56 | return total_loss 57 | 58 | def test_loop(args, dataloader, model): 59 | true_labels, true_predictions = [], [] 60 | model.eval() 61 | with torch.no_grad(): 62 | for batch_data in tqdm(dataloader): 63 | batch_data = to_device(args, batch_data) 64 | outputs = model(**batch_data) 65 | _, logits, masks, labels = outputs 66 | 67 | predictions = logits.argmax(dim=-1).cpu().numpy() # [batch, event_pair_num] 68 | y = labels.cpu().numpy() 69 | lens = np.sum(masks.cpu().numpy(), axis=-1) 70 | true_labels += [ 71 | int(l) for label, seq_len in zip(y, lens) for idx, l in enumerate(label) if idx < seq_len 72 | ] 73 | true_predictions += [ 74 | int(p) for pred, seq_len in zip(predictions, lens) for idx, p in enumerate(pred) if idx < seq_len 75 | ] 76 | return classification_report(true_labels, true_predictions, output_dict=True) 77 | 78 | def train(args, train_dataset, dev_dataset, model, tokenizer): 79 | """ Train the model """ 80 | train_dataloader = get_dataLoader(args, train_dataset, tokenizer, shuffle=True) 81 | dev_dataloader = get_dataLoader(args, dev_dataset, tokenizer, shuffle=False) 82 | t_total = len(train_dataloader) * args.num_train_epochs 83 | # Prepare optimizer and schedule (linear warmup and decay) 84 | no_decay = ["bias", "LayerNorm.weight"] 85 | optimizer_grouped_parameters = [ 86 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay}, 87 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0} 88 | ] 89 | args.warmup_steps = int(t_total * args.warmup_proportion) 90 | optimizer = AdamW( 91 | optimizer_grouped_parameters, 92 | lr=args.learning_rate, 93 | betas=(args.adam_beta1, args.adam_beta2), 94 | eps=args.adam_epsilon 95 | ) 96 | lr_scheduler = get_scheduler( 97 | 'linear', 98 | optimizer, 99 | num_warmup_steps=args.warmup_steps, 100 | num_training_steps=t_total 101 | ) 102 | # Train! 103 | logger.info("***** Running training *****") 104 | logger.info(f"Num examples - {len(train_dataset)}") 105 | logger.info(f"Num Epochs - {args.num_train_epochs}") 106 | logger.info(f"Total optimization steps - {t_total}") 107 | with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f: 108 | f.write(str(args)) 109 | 110 | total_loss = 0. 111 | best_f1 = 0. 
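# Note on the epoch loop below: it always saves the checkpoint that improves the best
# dev F1, and additionally saves any epoch whose dev precision and recall both exceed
# 69% -- presumably to keep balanced checkpoints around for later test-time selection.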
112 | for epoch in range(args.num_train_epochs): 113 | print(f"Epoch {epoch+1}/{args.num_train_epochs}\n-------------------------------") 114 | total_loss = train_loop(args, train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss) 115 | metrics = test_loop(args, dev_dataloader, model) 116 | dev_p, dev_r, dev_f1 = metrics['1']['precision'], metrics['1']['recall'], metrics['1']['f1-score'] 117 | logger.info(f'Dev: P - {(100*dev_p):0.4f} R - {(100*dev_r):0.4f} F1 - {(100*dev_f1):0.4f}') 118 | if dev_f1 > best_f1: 119 | best_f1 = dev_f1 120 | logger.info(f'saving new weights to {args.output_dir}...\n') 121 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin' 122 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight)) 123 | elif 100 * dev_p > 69 and 100 * dev_r > 69: 124 | logger.info(f'saving new weights to {args.output_dir}...\n') 125 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin' 126 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight)) 127 | with open(os.path.join(args.output_dir, 'dev_metrics.txt'), 'at') as f: 128 | f.write(f'epoch_{epoch+1}\n' + json.dumps(metrics, cls=NpEncoder) + '\n\n') 129 | logger.info("Done!") 130 | 131 | def predict(args, document:str, events:list, model, tokenizer): 132 | ''' 133 | # Args: 134 | - events: [ 135 | [e_char_start, e_char_end], ... 136 | ], document[e1_char_start:e1_char_end + 1] = trigger1 137 | ''' 138 | inputs = tokenizer( 139 | document, 140 | max_length=args.max_seq_length, 141 | truncation=True, 142 | return_tensors="pt" 143 | ) 144 | filtered_events = [] 145 | new_events = [] 146 | for event in events: 147 | char_start, char_end = event 148 | token_start = inputs.char_to_token(char_start) 149 | if not token_start: 150 | token_start = inputs.char_to_token(char_start + 1) 151 | token_end = inputs.char_to_token(char_end) 152 | if not token_start or not token_end: 153 | continue 154 | filtered_events.append([token_start, token_end]) 155 | new_events.append(event) 156 | if not new_events: 157 | return [], [], [] 158 | inputs = { 159 | 'batch_inputs': inputs, 160 | 'batch_events': [filtered_events] 161 | } 162 | inputs = to_device(args, inputs) 163 | with torch.no_grad(): 164 | outputs = model(**inputs) 165 | logits = outputs[1] 166 | predictions = logits.argmax(dim=-1)[0].cpu().numpy().tolist() 167 | probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy().tolist() 168 | probabilities = [probabilities[idx][pred] for idx, pred in enumerate(predictions)] 169 | if len(new_events) > 1: 170 | assert len(predictions) == len(new_events) * (len(new_events) - 1) / 2 171 | return new_events, predictions, probabilities 172 | 173 | def test(args, test_dataset, model, tokenizer, save_weights:list): 174 | test_dataloader = get_dataLoader(args, test_dataset, tokenizer, batch_size=1, shuffle=False) 175 | logger.info('***** Running testing *****') 176 | for save_weight in save_weights: 177 | logger.info(f'loading weights from {save_weight}...') 178 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight))) 179 | metrics = test_loop(args, test_dataloader, model) 180 | with open(os.path.join(args.output_dir, 'test_metrics.txt'), 'at') as f: 181 | f.write(f'{save_weight}\n{json.dumps(metrics, cls=NpEncoder)}\n\n') 182 | 183 | if __name__ == '__main__': 184 | args = parse_args() 185 | if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir): 186 | raise ValueError(f'Output directory 
({args.output_dir}) already exists and is not empty.') 187 | if not os.path.exists(args.output_dir): 188 | os.mkdir(args.output_dir) 189 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 190 | args.n_gpu = torch.cuda.device_count() 191 | logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}') 192 | # Set seed 193 | seed_everything(args.seed) 194 | # Load pretrained model and tokenizer 195 | logger.info(f'using model {"with" if args.add_contrastive_loss else "without"} Contrastive loss') 196 | logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...') 197 | config = AutoConfig.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir) 198 | tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir) 199 | args.num_labels = 2 200 | model = LongformerSoftmaxForEC.from_pretrained( 201 | args.model_checkpoint, 202 | config=config, 203 | cache_dir=args.cache_dir, 204 | args=args 205 | ).to(args.device) 206 | # Training 207 | if args.do_train: 208 | train_dataset = KBPCoref(args.train_file) 209 | dev_dataset = KBPCoref(args.dev_file) 210 | train(args, train_dataset, dev_dataset, model, tokenizer) 211 | # Testing 212 | save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')] 213 | if args.do_test: 214 | test_dataset = KBPCoref(args.test_file) 215 | test(args, test_dataset, model, tokenizer, save_weights) 216 | # Predicting 217 | if args.do_predict: 218 | pred_event_file = 'epoch_3_dev_f1_57.9994_weights.bin_test_pred_events.json' 219 | # pred_event_file = 'test_filtered.json' 220 | 221 | for best_save_weight in save_weights: 222 | logger.info(f'loading weights from {best_save_weight}...') 223 | model.load_state_dict(torch.load(os.path.join(args.output_dir, best_save_weight))) 224 | logger.info(f'predicting coref labels of {best_save_weight}...') 225 | results = [] 226 | model.eval() 227 | with open(os.path.join(args.output_dir, pred_event_file), 'rt', encoding='utf-8') as f_in: 228 | for line in tqdm(f_in.readlines()): 229 | sample = json.loads(line.strip()) 230 | events_from_file = sample['events'] if pred_event_file == 'test_filtered.json' else sample['pred_label'] 231 | events = [ 232 | [event['start'], event['start'] + len(event['trigger']) - 1] 233 | for event in events_from_file 234 | ] 235 | new_events, predictions, probabilities = predict(args, sample['document'], events, model, tokenizer) 236 | results.append({ 237 | "doc_id": sample['doc_id'], 238 | "document": sample['document'], 239 | "events": [ 240 | { 241 | 'start': char_start, 242 | 'end': char_end, 243 | 'trigger': sample['document'][char_start:char_end+1] 244 | } for char_start, char_end in new_events 245 | ], 246 | "pred_label": predictions, 247 | "pred_prob": probabilities 248 | }) 249 | save_name = '_gold_test_pred_corefs.json' if pred_event_file == 'test_filtered.json' else '_test_pred_corefs.json' 250 | with open(os.path.join(args.output_dir, best_save_weight + save_name), 'wt', encoding='utf-8') as f: 251 | for example_result in results: 252 | f.write(json.dumps(example_result) + '\n') 253 | # Analysis 254 | if args.do_analysis: 255 | pred_event_file = 'epoch_3_dev_f1_57.9994_weights.bin_test_pred_events.json' 256 | pred_event_filepath = os.path.join(args.output_dir, pred_event_file) 257 | 258 | analysis_weight = 'XXX_weights.bin' 259 | logger.info(f'loading weights from {analysis_weight}...') 260 | model.load_state_dict(torch.load(os.path.join(args.output_dir, analysis_weight))) 261 | logger.info(f'predicting 
coref labels of {analysis_weight}...') 262 | all_wrong_1, all_wrong_2 = [], [] 263 | all_predictions, all_labels = [], [] 264 | model.eval() 265 | with open(os.path.join(args.test_file), 'rt' , encoding='utf-8') as f_in: 266 | for line in tqdm(f_in.readlines()): 267 | sample = json.loads(line.strip()) 268 | events = [ 269 | [event['start'], event['start'] + len(event['trigger']) - 1] 270 | for event in sample['events'] 271 | ] 272 | new_events, predictions, _ = predict(args, sample['document'], events, model, tokenizer) 273 | all_predictions += predictions 274 | wrong_1_list, wrong_2_list, true_labels = get_wrong_samples( 275 | sample['doc_id'], 276 | new_events, predictions, 277 | sample['events'], sample['clusters'], sample['sentences'], 278 | pred_event_filepath 279 | ) 280 | all_labels += true_labels 281 | all_wrong_1 += wrong_1_list 282 | all_wrong_2 += wrong_2_list 283 | assert len(all_labels) == len(all_predictions) 284 | print(classification_report(all_labels, all_predictions)) 285 | print(f'all_wrong_1: {len(all_wrong_1)}\tall_wrong_2: {len(all_wrong_2)}') 286 | with open(os.path.join(args.output_dir, analysis_weight + '_' + WRONG_TYPE[0] + '.json'), 'wt', encoding='utf-8') as f_out_1: 287 | f_out_1.write(json.dumps(all_wrong_1) + '\n') 288 | with open(os.path.join(args.output_dir, analysis_weight + '_' + WRONG_TYPE[1] + '.json'), 'wt', encoding='utf-8') as f_out_2: 289 | f_out_2.write(json.dumps(all_wrong_2) + '\n') 290 | -------------------------------------------------------------------------------- /src/global_event_coref/run_global_base.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./M-multi-cosine_results/ 2 | 3 | python3 run_global_base.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --model_type=longformer \ 6 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \ 7 | --train_file=../../data/train_filtered.json \ 8 | --dev_file=../../data/dev_filtered.json \ 9 | --test_file=../../data/test_filtered.json \ 10 | --max_seq_length=4096 \ 11 | --learning_rate=1e-5 \ 12 | --add_contrastive_loss \ 13 | --matching_style=multi_cosine \ 14 | --softmax_loss=ce \ 15 | --num_train_epochs=30 \ 16 | --batch_size=1 \ 17 | --do_train \ 18 | --warmup_proportion=0. \ 19 | --seed=42 -------------------------------------------------------------------------------- /src/global_event_coref/run_global_base_with_mask.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./Mask_M-multi-cosine_closs_results/ 2 | 3 | python3 run_global_base_with_mask.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --model_type=longformer \ 6 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \ 7 | --mention_encoder_type=bert \ 8 | --mention_encoder_checkpoint=../../../PT_MODELS/bert-base-cased/ \ 9 | --train_file=../../data/train_filtered.json \ 10 | --dev_file=../../data/dev_filtered.json \ 11 | --test_file=../../data/test_filtered.json \ 12 | --max_seq_length=4096 \ 13 | --max_mention_length=256 \ 14 | --learning_rate=1e-5 \ 15 | --add_contrastive_loss \ 16 | --matching_style=multi_cosine \ 17 | --softmax_loss=ce \ 18 | --num_train_epochs=50 \ 19 | --batch_size=1 \ 20 | --do_train \ 21 | --warmup_proportion=0. 
\ 22 | --seed=42 -------------------------------------------------------------------------------- /src/global_event_coref/run_global_base_with_mask_topic.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./MaskTopic_M-multi-cosine_closs_results/ 2 | 3 | python3 run_global_base_with_mask_topic.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --model_type=longformer \ 6 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \ 7 | --mention_encoder_type=bert \ 8 | --mention_encoder_checkpoint=../../../PT_MODELS/bert-base-cased/ \ 9 | --topic_model=vmf \ 10 | --topic_dim=32 \ 11 | --topic_inter_map=64 \ 12 | --train_file=../../data/train_filtered.json \ 13 | --dev_file=../../data/dev_filtered.json \ 14 | --test_file=../../data/test_filtered.json \ 15 | --max_seq_length=4096 \ 16 | --max_mention_length=256 \ 17 | --learning_rate=1e-5 \ 18 | --add_contrastive_loss \ 19 | --matching_style=multi_cosine \ 20 | --softmax_loss=ce \ 21 | --num_train_epochs=50 \ 22 | --batch_size=1 \ 23 | --do_train \ 24 | --warmup_proportion=0. \ 25 | --seed=42 -------------------------------------------------------------------------------- /src/global_event_coref/run_global_base_with_topic.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from tqdm.auto import tqdm 3 | import json 4 | from collections import defaultdict, namedtuple 5 | import torch 6 | from transformers import AdamW, get_scheduler 7 | from transformers import AutoConfig, AutoTokenizer 8 | import numpy as np 9 | from sklearn.metrics import classification_report 10 | import os 11 | import sys 12 | sys.path.append('../../') 13 | from src.tools import seed_everything, NpEncoder 14 | from src.global_event_coref.arg import parse_args 15 | from src.global_event_coref.data import KBPCoref, get_dataLoader, vocab, VOCAB_SIZE 16 | from src.global_event_coref.modeling import LongformerSoftmaxForECwithTopic 17 | 18 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 19 | datefmt='%Y/%m/%d %H:%M:%S', 20 | level=logging.INFO) 21 | logger = logging.getLogger("Model") 22 | Sentence = namedtuple("Sentence", ["start", "text"]) 23 | 24 | def to_device(args, batch_data): 25 | new_batch_data = {} 26 | for k, v in batch_data.items(): 27 | if k in ['batch_events', 'batch_event_cluster_ids']: 28 | new_batch_data[k] = v 29 | elif k == 'batch_event_dists': 30 | new_batch_data[k] = [ 31 | torch.tensor(event_dists, dtype=torch.float32).to(args.device) 32 | for event_dists in v 33 | ] 34 | elif k == 'batch_inputs': 35 | new_batch_data[k] = { 36 | k_: v_.to(args.device) for k_, v_ in v.items() 37 | } 38 | else: 39 | raise ValueError(f'Unknown batch data key: {k}') 40 | return new_batch_data 41 | 42 | def train_loop(args, dataloader, model, optimizer, lr_scheduler, epoch, total_loss): 43 | progress_bar = tqdm(range(len(dataloader))) 44 | progress_bar.set_description(f'loss: {0:>7f}') 45 | finish_step_num = epoch * len(dataloader) 46 | 47 | model.train() 48 | for step, batch_data in enumerate(dataloader, start=1): 49 | batch_data = to_device(args, batch_data) 50 | outputs = model(**batch_data) 51 | loss = outputs[0] 52 | 53 | if loss: 54 | optimizer.zero_grad() 55 | loss.backward() 56 | optimizer.step() 57 | lr_scheduler.step() 58 | 59 | total_loss += loss.item() if loss else 0. 
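# NOTE: `loss` is None when a document yields no trainable event pair (fewer
# than two detected events), which is why both the backward pass above and the
# accumulation here are guarded. The bar below reports the running mean loss
# over all steps completed so far (across epochs), not the current batch loss.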
60 | progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}') 61 | progress_bar.update(1) 62 | return total_loss 63 | 64 | def test_loop(args, dataloader, model): 65 | true_labels, true_predictions = [], [] 66 | model.eval() 67 | with torch.no_grad(): 68 | for batch_data in tqdm(dataloader): 69 | batch_data = to_device(args, batch_data) 70 | outputs = model(**batch_data) 71 | _, logits, masks, labels = outputs 72 | 73 | predictions = logits.argmax(dim=-1).cpu().numpy() # [batch, event_pair_num] 74 | y = labels.cpu().numpy() 75 | lens = np.sum(masks.cpu().numpy(), axis=-1) 76 | true_labels += [ 77 | int(l) for label, seq_len in zip(y, lens) for idx, l in enumerate(label) if idx < seq_len 78 | ] 79 | true_predictions += [ 80 | int(p) for pred, seq_len in zip(predictions, lens) for idx, p in enumerate(pred) if idx < seq_len 81 | ] 82 | return classification_report(true_labels, true_predictions, output_dict=True) 83 | 84 | def train(args, train_dataset, dev_dataset, model, tokenizer): 85 | """ Train the model """ 86 | train_dataloader = get_dataLoader(args, train_dataset, tokenizer, shuffle=True, collote_fn_type='with_dist') 87 | dev_dataloader = get_dataLoader(args, dev_dataset, tokenizer, shuffle=False, collote_fn_type='with_dist') 88 | t_total = len(train_dataloader) * args.num_train_epochs 89 | # Prepare optimizer and schedule (linear warmup and decay) 90 | no_decay = ["bias", "LayerNorm.weight"] 91 | optimizer_grouped_parameters = [ 92 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay}, 93 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0} 94 | ] 95 | args.warmup_steps = int(t_total * args.warmup_proportion) 96 | optimizer = AdamW( 97 | optimizer_grouped_parameters, 98 | lr=args.learning_rate, 99 | betas=(args.adam_beta1, args.adam_beta2), 100 | eps=args.adam_epsilon 101 | ) 102 | lr_scheduler = get_scheduler( 103 | 'linear', 104 | optimizer, 105 | num_warmup_steps=args.warmup_steps, 106 | num_training_steps=t_total 107 | ) 108 | # Train! 109 | logger.info("***** Running training *****") 110 | logger.info(f"Num examples - {len(train_dataset)}") 111 | logger.info(f"Num Epochs - {args.num_train_epochs}") 112 | logger.info(f"Total optimization steps - {t_total}") 113 | with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f: 114 | f.write(str(args)) 115 | 116 | total_loss = 0. 117 | best_f1 = 0. 
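# Checkpointing policy: a weight file is written whenever dev F1 improves, and
# additionally whenever dev P and R both exceed 69%, so several
# epoch_*_weights.bin files can accumulate; the later test/predict stages then
# iterate over every *.bin file found in the output directory.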
118 | for epoch in range(args.num_train_epochs): 119 | print(f"Epoch {epoch+1}/{args.num_train_epochs}\n-------------------------------") 120 | total_loss = train_loop(args, train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss) 121 | metrics = test_loop(args, dev_dataloader, model) 122 | dev_p, dev_r, dev_f1 = metrics['1']['precision'], metrics['1']['recall'], metrics['1']['f1-score'] 123 | logger.info(f'Dev: P - {(100*dev_p):0.4f} R - {(100*dev_r):0.4f} F1 - {(100*dev_f1):0.4f}') 124 | if dev_f1 > best_f1: 125 | best_f1 = dev_f1 126 | logger.info(f'saving new weights to {args.output_dir}...\n') 127 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin' 128 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight)) 129 | elif 100 * dev_p > 69 and 100 * dev_r > 69: 130 | logger.info(f'saving new weights to {args.output_dir}...\n') 131 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin' 132 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight)) 133 | with open(os.path.join(args.output_dir, 'dev_metrics.txt'), 'at') as f: 134 | f.write(f'epoch_{epoch+1}\n' + json.dumps(metrics, cls=NpEncoder) + '\n\n') 135 | logger.info("Done!") 136 | 137 | def predict(args, document:str, events:list, event_dists:list, model, tokenizer): 138 | assert len(events) == len(event_dists) 139 | inputs = tokenizer( 140 | document, 141 | max_length=args.max_seq_length, 142 | truncation=True, 143 | return_tensors="pt" 144 | ) 145 | filtered_events = [] 146 | new_events = [] 147 | filtered_dists = [] 148 | for event, event_dist in zip(events, event_dists): 149 | char_start, char_end = event 150 | token_start = inputs.char_to_token(char_start) 151 | if not token_start: 152 | token_start = inputs.char_to_token(char_start + 1) 153 | token_end = inputs.char_to_token(char_end) 154 | if not token_start or not token_end: 155 | continue 156 | filtered_events.append([token_start, token_end]) 157 | new_events.append(event) 158 | filtered_dists.append(event_dist) 159 | if not new_events: 160 | return [], [], [] 161 | inputs = { 162 | 'batch_inputs': inputs, 163 | 'batch_events': [filtered_events], 164 | 'batch_event_dists': [np.asarray(filtered_dists)] 165 | } 166 | inputs = to_device(args, inputs) 167 | with torch.no_grad(): 168 | outputs = model(**inputs) 169 | logits = outputs[1] 170 | predictions = logits.argmax(dim=-1)[0].cpu().numpy().tolist() 171 | probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy().tolist() 172 | probabilities = [probabilities[idx][pred] for idx, pred in enumerate(predictions)] 173 | if len(new_events) > 1: 174 | assert len(predictions) == len(new_events) * (len(new_events) - 1) / 2 175 | return new_events, predictions, probabilities 176 | 177 | def test(args, test_dataset, model, tokenizer, save_weights:list): 178 | test_dataloader = get_dataLoader( 179 | args, test_dataset, tokenizer, batch_size=1, shuffle=False, 180 | collote_fn_type='with_dist' 181 | ) 182 | logger.info('***** Running testing *****') 183 | for save_weight in save_weights: 184 | logger.info(f'loading weights from {save_weight}...') 185 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight))) 186 | metrics = test_loop(args, test_dataloader, model) 187 | with open(os.path.join(args.output_dir, 'test_metrics.txt'), 'at') as f: 188 | f.write(f'{save_weight}\n{json.dumps(metrics, cls=NpEncoder)}\n\n') 189 | 190 | def get_event_dist(e_start, e_end, sents): 191 | for s_idx, sent in enumerate(sents): 
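# Locate the sentence that fully contains the trigger span, then build a
# bag-of-words indicator over the shared `vocab` for that sentence plus its
# immediate neighbours; this "event distribution" is what the topic model
# consumes. A minimal sketch (these vocab entries are made up for illustration):
#   vocab = ['bomb', 'attack', 'verdict']
#   # trigger inside "The bomb went off." -> returns [1, 0, 0]
# If no sentence covers the span, None is returned (callers assert against it).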
192 | sent_end = sent.start + len(sent.text) - 1 193 | if e_start >= sent.start and e_end <= sent_end: 194 | before = sents[s_idx - 1].text if s_idx > 0 else '' 195 | after = sents[s_idx + 1].text if s_idx < len(sents) - 1 else '' 196 | event_mention = before + (' ' if len(before) > 0 else '') + sent.text + ' ' + after 197 | event_mention = event_mention.lower() 198 | return [1 if w in event_mention else 0 for w in vocab] 199 | return None 200 | 201 | if __name__ == '__main__': 202 | args = parse_args() 203 | if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir): 204 | raise ValueError(f'Output directory ({args.output_dir}) already exists and is not empty.') 205 | if not os.path.exists(args.output_dir): 206 | os.mkdir(args.output_dir) 207 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 208 | args.n_gpu = torch.cuda.device_count() 209 | logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}') 210 | # Set seed 211 | seed_everything(args.seed) 212 | # Load pretrained model and tokenizer 213 | logger.info(f'using model {"with" if args.add_contrastive_loss else "without"} Contrastive loss') 214 | logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...') 215 | main_config = AutoConfig.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir) 216 | tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir) 217 | args.num_labels = 2 218 | args.dist_dim = VOCAB_SIZE 219 | model = LongformerSoftmaxForECwithTopic.from_pretrained( 220 | args.model_checkpoint, 221 | config=main_config, 222 | cache_dir=args.cache_dir, 223 | args=args 224 | ).to(args.device) 225 | # Training 226 | save_weights = [] 227 | if args.do_train: 228 | train_dataset = KBPCoref(args.train_file) 229 | dev_dataset = KBPCoref(args.dev_file) 230 | train(args, train_dataset, dev_dataset, model, tokenizer) 231 | # Testing 232 | save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')] 233 | if args.do_test: 234 | test_dataset = KBPCoref(args.test_file) 235 | test(args, test_dataset, model, tokenizer, save_weights) 236 | # Predicting 237 | if args.do_predict: 238 | kbp_sent_dic = defaultdict(list) # {filename: [Sentence]} 239 | with open(os.path.join('../../data/kbp_sent.txt'), 'rt', encoding='utf-8') as sents: 240 | for line in sents: 241 | doc_id, start, text = line.strip().split('\t') 242 | kbp_sent_dic[doc_id].append(Sentence(int(start), text)) 243 | 244 | pred_event_file = 'epoch_3_dev_f1_57.9994_weights.bin_test_pred_events.json' 245 | # pred_event_file = 'test_filtered.json' 246 | 247 | for best_save_weight in save_weights: 248 | logger.info(f'loading weights from {best_save_weight}...') 249 | model.load_state_dict(torch.load(os.path.join(args.output_dir, best_save_weight))) 250 | logger.info(f'predicting coref labels of {best_save_weight}...') 251 | 252 | results = [] 253 | model.eval() 254 | with open(os.path.join(args.output_dir, pred_event_file), 'rt' , encoding='utf-8') as f_in: 255 | for line in tqdm(f_in.readlines()): 256 | sample = json.loads(line.strip()) 257 | events_from_file = sample['events'] if pred_event_file == 'test_filtered.json' else sample['pred_label'] 258 | events = [ 259 | [event['start'], event['start'] + len(event['trigger']) - 1] 260 | for event in events_from_file 261 | ] 262 | sents = kbp_sent_dic[sample['doc_id']] 263 | event_dists = [] 264 | for event in events_from_file: 265 | e_dist = get_event_dist(event['start'], event['start'] + len(event['trigger']) - 
1, sents) 266 | assert e_dist is not None 267 | event_dists.append(e_dist) 268 | new_events, predictions, probabilities = predict( 269 | args, sample['document'], events, event_dists, model, tokenizer 270 | ) 271 | results.append({ 272 | "doc_id": sample['doc_id'], 273 | "document": sample['document'], 274 | "events": [ 275 | { 276 | 'start': char_start, 277 | 'end': char_end, 278 | 'trigger': sample['document'][char_start:char_end+1] 279 | } for char_start, char_end in new_events 280 | ], 281 | "pred_label": predictions, 282 | "pred_prob": probabilities 283 | }) 284 | save_name = '_gold_test_pred_corefs.json' if pred_event_file == 'test_filtered.json' else '_test_pred_corefs.json' 285 | with open(os.path.join(args.output_dir, best_save_weight + save_name), 'wt', encoding='utf-8') as f: 286 | for example_result in results: 287 | f.write(json.dumps(example_result) + '\n') -------------------------------------------------------------------------------- /src/global_event_coref/run_global_base_with_topic.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./Topic_M-multi-cosine_closs_results/ 2 | 3 | python3 run_global_base_with_topic.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --model_type=longformer \ 6 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \ 7 | --topic_model=vmf \ 8 | --topic_dim=32 \ 9 | --topic_inter_map=64 \ 10 | --train_file=../../data/train_filtered.json \ 11 | --dev_file=../../data/dev_filtered.json \ 12 | --test_file=../../data/test_filtered.json \ 13 | --max_seq_length=4096 \ 14 | --learning_rate=1e-5 \ 15 | --add_contrastive_loss \ 16 | --matching_style=multi_cosine \ 17 | --softmax_loss=ce \ 18 | --num_train_epochs=50 \ 19 | --batch_size=1 \ 20 | --do_train \ 21 | --warmup_proportion=0.
\ 22 | --seed=42 -------------------------------------------------------------------------------- /src/joint_model/arg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # Required parameters 7 | parser.add_argument("--output_dir", default=None, type=str, required=True, 8 | help="The output directory where the model checkpoints and predictions will be written.", 9 | ) 10 | parser.add_argument("--train_file", default=None, type=str, required=True, help="The input training file.") 11 | parser.add_argument("--dev_file", default=None, type=str, required=True, help="The input evaluation file.") 12 | parser.add_argument("--test_file", default=None, type=str, required=True, help="The input testing file.") 13 | 14 | parser.add_argument("--model_type", 15 | default="longformer", type=str, required=True 16 | ) 17 | parser.add_argument("--model_checkpoint", 18 | default="allenai/longformer-base-4096", type=str, required=True, 19 | help="Path to pretrained model or model identifier from huggingface.co/models", 20 | ) 21 | parser.add_argument("--max_seq_length", default=4096, type=int, required=True) 22 | parser.add_argument("--matching_style", default="multi", type=str, required=True, 23 | help="how to match two event representations" 24 | ) 25 | 26 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 27 | parser.add_argument("--do_test", action="store_true", help="Whether to run eval on the test set.") 28 | parser.add_argument("--do_predict", action="store_true", help="Whether to save predicted labels.") 29 | 30 | # Other parameters 31 | parser.add_argument("--cache_dir", default=None, type=str, 32 | help="Where do you want to store the pre-trained models downloaded from s3" 33 | ) 34 | parser.add_argument("--topic_model", default='stm', type=str, 35 | choices=['stm', 'stm_bn', 'vmf'] 36 | ) 37 | parser.add_argument("--topic_dim", default=32, type=int) 38 | parser.add_argument("--topic_inter_map", default=64, type=int) 39 | parser.add_argument("--mention_encoder_type", default="bert", type=str) 40 | parser.add_argument("--mention_encoder_checkpoint", 41 | default="bert-base-cased", type=str, 42 | help="Path to pretrained model or model identifier from huggingface.co/models", 43 | ) 44 | parser.add_argument("--max_mention_length", default=256, type=int) 45 | parser.add_argument("--add_contrastive_loss", action="store_true") 46 | parser.add_argument("--softmax_loss", default='ce', type=str, 47 | help="The loss function for softmax model.", 48 | choices=['lsr', 'focal', 'ce'] 49 | ) 50 | 51 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") 52 | parser.add_argument("--num_train_epochs", default=30, type=int, help="Total number of training epochs to perform.") 53 | parser.add_argument("--batch_size", default=1, type=int) 54 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 55 | 56 | parser.add_argument("--adam_beta1", default=0.9, type=float, 57 | help="Beta1 for Adam optimizer." 58 | ) 59 | parser.add_argument("--adam_beta2", default=0.98, type=float, 60 | help="Beta2 for Adam optimizer." 61 | ) 62 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 63 | help="Epsilon for Adam optimizer."
64 | ) 65 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 66 | help="Proportion of training to perform linear learning rate warmup for,E.g., 0.1 = 10% of training." 67 | ) 68 | parser.add_argument("--weight_decay", default=0.01, type=float, 69 | help="Weight decay if we apply some." 70 | ) 71 | args = parser.parse_args() 72 | return args -------------------------------------------------------------------------------- /src/joint_model/modeling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import CrossEntropyLoss 4 | from transformers import LongformerPreTrainedModel, LongformerModel 5 | from transformers import BertModel, RobertaModel 6 | from allennlp.modules.span_extractors import SelfAttentiveSpanExtractor 7 | from ..tools import LabelSmoothingCrossEntropy, FocalLoss 8 | from ..tools import SimpleTopicModel, SimpleTopicModelwithBN, SimpleTopicVMFModel 9 | 10 | MENTION_ENCODER = { 11 | 'bert': BertModel, 12 | 'roberta': RobertaModel 13 | } 14 | TOPIC_MODEL = { 15 | 'stm': SimpleTopicModel, 16 | 'stm_bn': SimpleTopicModelwithBN, 17 | 'vmf': SimpleTopicVMFModel 18 | } 19 | COSINE_SPACE_DIM = 64 20 | COSINE_SLICES = 128 21 | COSINE_FACTOR = 4 22 | 23 | class LongformerSoftmaxForEC(LongformerPreTrainedModel): 24 | def __init__(self, config, args): 25 | super().__init__(config) 26 | self.trigger_num_labels = args.trigger_num_labels 27 | self.num_labels = args.num_labels 28 | self.hidden_size = config.hidden_size 29 | self.loss_type = args.softmax_loss 30 | self.add_contrastive_loss = args.add_contrastive_loss 31 | self.use_device = args.device 32 | # encoder & pooler 33 | self.longformer = LongformerModel(config, add_pooling_layer=False) 34 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 35 | self.span_extractor = SelfAttentiveSpanExtractor(input_dim=self.hidden_size) 36 | self.td_classifier = nn.Linear(self.hidden_size, self.trigger_num_labels) 37 | # event matching 38 | self.matching_style = args.matching_style 39 | if 'cosine' not in self.matching_style: 40 | if self.matching_style == 'base': 41 | multiples = 2 42 | elif self.matching_style == 'multi': 43 | multiples = 3 44 | self.coref_classifier = nn.Linear(multiples * self.hidden_size, self.num_labels) 45 | else: 46 | self.cosine_space_dim, self.cosine_slices, self.tensor_factor = COSINE_SPACE_DIM, COSINE_SLICES, COSINE_FACTOR 47 | self.cosine_mat_p = nn.Parameter(torch.rand((self.tensor_factor, self.cosine_slices), requires_grad=True)) 48 | self.cosine_mat_q = nn.Parameter(torch.rand((self.tensor_factor, self.cosine_space_dim), requires_grad=True)) 49 | self.cosine_ffnn = nn.Linear(self.hidden_size, self.cosine_space_dim) 50 | if self.matching_style == 'cosine': 51 | self.coref_classifier = nn.Linear(2 * self.hidden_size + self.cosine_slices, self.num_labels) 52 | elif self.matching_style == 'multi_cosine': 53 | self.coref_classifier = nn.Linear(3 * self.hidden_size + self.cosine_slices, self.num_labels) 54 | elif self.matching_style == 'multi_dist_cosine': 55 | self.coref_classifier = nn.Linear(4 * self.hidden_size + self.cosine_slices, self.num_labels) 56 | self.post_init() 57 | 58 | def _multi_cosine(self, batch_event_1_reps, batch_event_2_reps): 59 | batch_event_1_reps = self.cosine_ffnn(batch_event_1_reps) 60 | batch_event_1_reps = batch_event_1_reps.unsqueeze(dim=2) 61 | batch_event_1_reps = self.cosine_mat_q * batch_event_1_reps 62 | batch_event_1_reps = batch_event_1_reps.permute((0, 1, 3, 2)) 63 | 
batch_event_1_reps = torch.matmul(batch_event_1_reps, self.cosine_mat_p) 64 | batch_event_1_reps = batch_event_1_reps.permute((0, 1, 3, 2)) 65 | # vector normalization 66 | norms_1 = (batch_event_1_reps ** 2).sum(axis=-1, keepdims=True) ** 0.5 67 | batch_event_1_reps = batch_event_1_reps / norms_1 68 | 69 | batch_event_2_reps = self.cosine_ffnn(batch_event_2_reps) 70 | batch_event_2_reps = batch_event_2_reps.unsqueeze(dim=2) 71 | batch_event_2_reps = self.cosine_mat_q * batch_event_2_reps 72 | batch_event_2_reps = batch_event_2_reps.permute((0, 1, 3, 2)) 73 | batch_event_2_reps = torch.matmul(batch_event_2_reps, self.cosine_mat_p) 74 | batch_event_2_reps = batch_event_2_reps.permute((0, 1, 3, 2)) 75 | # vector normalization 76 | norms_2 = (batch_event_2_reps ** 2).sum(axis=-1, keepdims=True) ** 0.5 77 | batch_event_2_reps = batch_event_2_reps / norms_2 78 | 79 | return torch.sum(batch_event_1_reps * batch_event_2_reps, dim=-1) 80 | 81 | def _cal_circle_loss(self, event_1_reps, event_2_reps, coref_labels, l=20.): 82 | norms_1 = (event_1_reps ** 2).sum(axis=1, keepdims=True) ** 0.5 83 | event_1_reps = event_1_reps / norms_1 84 | norms_2 = (event_2_reps ** 2).sum(axis=1, keepdims=True) ** 0.5 85 | event_2_reps = event_2_reps / norms_2 86 | event_cos = torch.sum(event_1_reps * event_2_reps, dim=1) * l 87 | # calculate the difference between each pair of Cosine values 88 | event_cos_diff = event_cos[:, None] - event_cos[None, :] 89 | # find (noncoref, coref) index 90 | select_idx = coref_labels[:, None] < coref_labels[None, :] 91 | select_idx = select_idx.float() 92 | 93 | event_cos_diff = event_cos_diff - (1 - select_idx) * 1e12 94 | event_cos_diff = event_cos_diff.view(-1) 95 | event_cos_diff = torch.cat((torch.tensor([0.0], device=self.use_device), event_cos_diff), dim=0) 96 | return torch.logsumexp(event_cos_diff, dim=0) 97 | 98 | def _matching_func(self, batch_event_1_reps, batch_event_2_reps): 99 | if self.matching_style == 'base': 100 | batch_seq_reps = torch.cat([batch_event_1_reps, batch_event_2_reps], dim=-1) 101 | elif self.matching_style == 'multi': 102 | batch_e1_e2_multi = batch_event_1_reps * batch_event_2_reps 103 | batch_seq_reps = torch.cat([batch_event_1_reps, batch_event_2_reps, batch_e1_e2_multi], dim=-1) 104 | elif self.matching_style == 'cosine': 105 | batch_multi_cosine = self._multi_cosine(batch_event_1_reps, batch_event_2_reps) 106 | batch_seq_reps = torch.cat([batch_event_1_reps, batch_event_2_reps, batch_multi_cosine], dim=-1) 107 | elif self.matching_style == 'multi_cosine': 108 | batch_e1_e2_multi = batch_event_1_reps * batch_event_2_reps 109 | batch_multi_cosine = self._multi_cosine(batch_event_1_reps, batch_event_2_reps) 110 | batch_seq_reps = torch.cat([batch_event_1_reps, batch_event_2_reps, batch_e1_e2_multi, batch_multi_cosine], dim=-1) 111 | elif self.matching_style == 'multi_dist_cosine': 112 | batch_e1_e2_multi = batch_event_1_reps * batch_event_2_reps 113 | batch_e1_e2_dist = torch.abs(batch_event_1_reps - batch_event_2_reps) 114 | batch_multi_cosine = self._multi_cosine(batch_event_1_reps, batch_event_2_reps) 115 | batch_seq_reps = torch.cat([batch_event_1_reps, batch_event_2_reps, batch_e1_e2_multi, batch_e1_e2_dist, batch_multi_cosine], dim=-1) 116 | return batch_seq_reps 117 | 118 | def forward(self, batch_inputs, batch_events=None, batch_td_labels=None, batch_event_cluster_ids=None): 119 | outputs = self.longformer(**batch_inputs) 120 | sequence_output = outputs[0] 121 | sequence_output = self.dropout(sequence_output) 122 | # predict trigger 123 | 
td_logits = self.td_classifier(sequence_output) 124 | if batch_events is None: 125 | return None, td_logits 126 | # construct event pairs (event_1, event_2) 127 | batch_event_1_list, batch_event_2_list = [], [] 128 | max_len, batch_event_mask = 0, [] 129 | if batch_event_cluster_ids is not None: 130 | batch_coref_labels = [] 131 | for events, event_cluster_ids in zip(batch_events, batch_event_cluster_ids): 132 | event_1_list, event_2_list, coref_labels = [], [], [] 133 | for i in range(len(events) - 1): 134 | for j in range(i + 1, len(events)): 135 | event_1_list.append(events[i]) 136 | event_2_list.append(events[j]) 137 | cluster_id_1, cluster_id_2 = event_cluster_ids[i], event_cluster_ids[j] 138 | coref_labels.append(1 if cluster_id_1 == cluster_id_2 else 0) 139 | max_len = max(max_len, len(coref_labels)) 140 | batch_event_1_list.append(event_1_list) 141 | batch_event_2_list.append(event_2_list) 142 | batch_coref_labels.append(coref_labels) 143 | batch_event_mask.append([1] * len(coref_labels)) 144 | # padding 145 | for b_idx in range(len(batch_coref_labels)): 146 | pad_length = max_len - len(batch_coref_labels[b_idx]) if max_len > 0 else 1 147 | batch_event_1_list[b_idx] += [[0, 0]] * pad_length 148 | batch_event_2_list[b_idx] += [[0, 0]] * pad_length 149 | batch_coref_labels[b_idx] += [0] * pad_length 150 | batch_event_mask[b_idx] += [0] * pad_length 151 | else: 152 | for events in batch_events: 153 | event_1_list, event_2_list = [], [] 154 | for i in range(len(events) - 1): 155 | for j in range(i + 1, len(events)): 156 | event_1_list.append(events[i]) 157 | event_2_list.append(events[j]) 158 | max_len = max(max_len, len(event_1_list)) 159 | batch_event_1_list.append(event_1_list) 160 | batch_event_2_list.append(event_2_list) 161 | batch_event_mask.append([1] * len(event_1_list)) 162 | # padding 163 | for b_idx in range(len(batch_event_mask)): 164 | pad_length = max_len - len(batch_event_mask[b_idx]) if max_len > 0 else 1 165 | batch_event_1_list[b_idx] += [[0, 0]] * pad_length 166 | batch_event_2_list[b_idx] += [[0, 0]] * pad_length 167 | batch_event_mask[b_idx] += [0] * pad_length 168 | # extract events & predict coref 169 | batch_event_1 = torch.tensor(batch_event_1_list).to(self.use_device) 170 | batch_event_2 = torch.tensor(batch_event_2_list).to(self.use_device) 171 | batch_mask = torch.tensor(batch_event_mask).to(self.use_device) 172 | batch_event_1_reps = self.span_extractor(sequence_output, batch_event_1, span_indices_mask=batch_mask) 173 | batch_event_2_reps = self.span_extractor(sequence_output, batch_event_2, span_indices_mask=batch_mask) 174 | batch_seq_reps = self._matching_func(batch_event_1_reps, batch_event_2_reps) 175 | coref_logits = self.coref_classifier(batch_seq_reps) 176 | # calculate loss 177 | loss, batch_ec_labels = None, None 178 | attention_mask = batch_inputs['attention_mask'] 179 | if batch_event_cluster_ids is not None and max_len > 0: 180 | assert self.loss_type in ['lsr', 'focal', 'ce'] 181 | if self.loss_type == 'lsr': 182 | loss_fct = LabelSmoothingCrossEntropy() 183 | elif self.loss_type == 'focal': 184 | loss_fct = FocalLoss() 185 | else: 186 | loss_fct = CrossEntropyLoss() 187 | # trigger detection loss 188 | active_td_loss = attention_mask.view(-1) == 1 189 | active_td_logits = td_logits.view(-1, self.trigger_num_labels)[active_td_loss] 190 | active_td_labels = batch_td_labels.view(-1)[active_td_loss] 191 | loss_td = loss_fct(active_td_logits, active_td_labels) 192 | # event coreference loss 193 | active_coref_loss = batch_mask.view(-1) == 1 194 | 
active_coref_logits = coref_logits.view(-1, self.num_labels)[active_coref_loss] 195 | batch_ec_labels = torch.tensor(batch_coref_labels).to(self.use_device) 196 | active_coref_labels = batch_ec_labels.view(-1)[active_coref_loss] 197 | loss_coref = loss_fct(active_coref_logits, active_coref_labels) 198 | if self.add_contrastive_loss: 199 | active_event_1_reps = batch_event_1_reps.view(-1, self.hidden_size)[active_coref_loss] 200 | active_event_2_reps = batch_event_2_reps.view(-1, self.hidden_size)[active_coref_loss] 201 | loss_contrastive = self._cal_circle_loss(active_event_1_reps, active_event_2_reps, active_coref_labels) 202 | loss = torch.log(1 + loss_td) + torch.log(1 + loss_coref) + 0.2 * loss_contrastive 203 | else: 204 | loss = torch.log(1 + loss_td) + torch.log(1 + loss_coref) 205 | return loss, td_logits, coref_logits, attention_mask, batch_td_labels, batch_mask, batch_ec_labels 206 | -------------------------------------------------------------------------------- /src/joint_model/run_joint_base.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./M-multi-cosine_results/ 2 | 3 | python3 run_joint_base.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --model_type=longformer \ 6 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \ 7 | --train_file=../../data/train_filtered.json \ 8 | --dev_file=../../data/dev_filtered.json \ 9 | --test_file=../../data/test_filtered.json \ 10 | --max_seq_length=4096 \ 11 | --learning_rate=1e-5 \ 12 | --add_contrastive_loss \ 13 | --matching_style=multi_cosine \ 14 | --softmax_loss=ce \ 15 | --num_train_epochs=30 \ 16 | --batch_size=1 \ 17 | --do_train \ 18 | --warmup_proportion=0. \ 19 | --seed=42 -------------------------------------------------------------------------------- /src/local_event_coref/arg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # Required parameters 7 | parser.add_argument("--output_dir", default=None, type=str, required=True, 8 | help="The output directory where the model checkpoints and predictions will be written.", 9 | ) 10 | parser.add_argument("--train_file", default=None, type=str, required=True, help="The input training file.") 11 | parser.add_argument("--dev_file", default=None, type=str, required=True, help="The input evaluation file.") 12 | parser.add_argument("--test_file", default=None, type=str, required=True, help="The input testing file.") 13 | 14 | parser.add_argument("--model_type", 15 | default="bert", type=str, required=True 16 | ) 17 | parser.add_argument("--model_checkpoint", 18 | default="bert-large-cased/", type=str, required=True, 19 | help="Path to pretrained model or model identifier from huggingface.co/models", 20 | ) 21 | parser.add_argument("--max_seq_length", default=512, type=int, required=True) 22 | parser.add_argument("--matching_style", default="multi", type=str, required=True, 23 | help="how to match two event representations" 24 | ) 25 | 26 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 27 | parser.add_argument("--do_test", action="store_true", help="Whether to run eval on the test set.") 28 | parser.add_argument("--do_predict", action="store_true", help="Whether to save predicted labels.") 29 | 30 | # Other parameters 31 | parser.add_argument("--cache_dir", default=None, type=str, 32 | help="Where do you want to store the pre-trained models downloaded from
s3" 33 | ) 34 | parser.add_argument("--topic_model", default='stm', type=str, 35 | choices=['stm', 'stm_bn', 'vmf'] 36 | ) 37 | parser.add_argument("--topic_dim", default=32, type=int) 38 | parser.add_argument("--topic_inter_map", default=64, type=int) 39 | parser.add_argument("--softmax_loss", default='ce', type=str, 40 | help="The loss function for softmax model.", 41 | choices=['lsr', 'focal', 'ce'] 42 | ) 43 | 44 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") 45 | parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.") 46 | parser.add_argument("--batch_size", default=4, type=int) 47 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 48 | 49 | parser.add_argument("--adam_beta1", default=0.9, type=float, 50 | help="Epsilon for Adam optimizer." 51 | ) 52 | parser.add_argument("--adam_beta2", default=0.98, type=float, 53 | help="Epsilon for Adam optimizer." 54 | ) 55 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 56 | help="Epsilon for Adam optimizer." 57 | ) 58 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 59 | help="Proportion of training to perform linear learning rate warmup for,E.g., 0.1 = 10% of training." 60 | ) 61 | parser.add_argument("--weight_decay", default=0.01, type=float, 62 | help="Weight decay if we apply some." 63 | ) 64 | args = parser.parse_args() 65 | return args 66 | -------------------------------------------------------------------------------- /src/local_event_coref/run_local_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import json 5 | from collections import namedtuple, defaultdict 6 | from tqdm.auto import tqdm 7 | from transformers import AutoConfig, AutoTokenizer 8 | from transformers import AdamW, get_scheduler 9 | from sklearn.metrics import classification_report 10 | import sys 11 | sys.path.append('../../') 12 | from src.tools import seed_everything, NpEncoder 13 | from src.local_event_coref.arg import parse_args 14 | from src.local_event_coref.data import KBPCorefPair, get_dataLoader 15 | from src.local_event_coref.modeling import BertForPairwiseEC, RobertaForPairwiseEC 16 | 17 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 18 | datefmt='%Y/%m/%d %H:%M:%S', 19 | level=logging.INFO) 20 | logger = logging.getLogger("Model") 21 | Sentence = namedtuple("Sentence", ["start", "text"]) 22 | 23 | MODEL_CLASSES = { 24 | 'bert': BertForPairwiseEC, 25 | 'spanbert': BertForPairwiseEC, 26 | 'roberta': RobertaForPairwiseEC 27 | } 28 | 29 | def to_device(args, batch_data): 30 | new_batch_data = {} 31 | for k, v in batch_data.items(): 32 | if k == 'batch_inputs': 33 | new_batch_data[k] = { 34 | k_: v_.to(args.device) for k_, v_ in v.items() 35 | } 36 | else: 37 | new_batch_data[k] = torch.tensor(v).to(args.device) 38 | return new_batch_data 39 | 40 | def train_loop(args, dataloader, model, optimizer, lr_scheduler, epoch, total_loss): 41 | progress_bar = tqdm(range(len(dataloader))) 42 | progress_bar.set_description(f'loss: {0:>7f}') 43 | finish_step_num = epoch * len(dataloader) 44 | 45 | model.train() 46 | for step, batch_data in enumerate(dataloader, start=1): 47 | batch_data = to_device(args, batch_data) 48 | outputs = model(**batch_data) 49 | loss = outputs[0] 50 | 51 | optimizer.zero_grad() 52 | loss.backward() 53 | 
optimizer.step() 54 | lr_scheduler.step() 55 | 56 | total_loss += loss.item() 57 | progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}') 58 | progress_bar.update(1) 59 | return total_loss 60 | 61 | def test_loop(args, dataloader, model): 62 | true_labels, true_predictions = [], [] 63 | model.eval() 64 | with torch.no_grad(): 65 | for batch_data in tqdm(dataloader): 66 | batch_data = to_device(args, batch_data) 67 | outputs = model(**batch_data) 68 | logits = outputs[1] 69 | 70 | predictions = logits.argmax(dim=-1).cpu().numpy().tolist() 71 | labels = batch_data['labels'].cpu().numpy() 72 | true_predictions += predictions 73 | true_labels += [int(label) for label in labels] 74 | return classification_report(true_labels, true_predictions, output_dict=True) 75 | 76 | def train(args, train_dataset, dev_dataset, model, tokenizer): 77 | """ Train the model """ 78 | train_dataloader = get_dataLoader(args, train_dataset, tokenizer, shuffle=True) 79 | dev_dataloader = get_dataLoader(args, dev_dataset, tokenizer, shuffle=False) 80 | t_total = len(train_dataloader) * args.num_train_epochs 81 | # Prepare optimizer and schedule (linear warmup and decay) 82 | no_decay = ["bias", "LayerNorm.weight"] 83 | optimizer_grouped_parameters = [ 84 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay}, 85 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0} 86 | ] 87 | args.warmup_steps = int(t_total * args.warmup_proportion) 88 | optimizer = AdamW( 89 | optimizer_grouped_parameters, 90 | lr=args.learning_rate, 91 | betas=(args.adam_beta1, args.adam_beta2), 92 | eps=args.adam_epsilon 93 | ) 94 | lr_scheduler = get_scheduler( 95 | 'linear', 96 | optimizer, 97 | num_warmup_steps=args.warmup_steps, 98 | num_training_steps=t_total 99 | ) 100 | # Train! 101 | logger.info("***** Running training *****") 102 | logger.info(f"Num examples - {len(train_dataset)}") 103 | logger.info(f"Num Epochs - {args.num_train_epochs}") 104 | logger.info(f"Total optimization steps - {t_total}") 105 | with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f: 106 | f.write(str(args)) 107 | 108 | total_loss = 0. 109 | best_f1 = 0. 
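# Unlike the global scripts above, the pairwise model keeps a checkpoint only
# when dev F1 improves, so the most recent epoch_*_weights.bin is also the best.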
110 | for epoch in range(args.num_train_epochs): 111 | print(f"Epoch {epoch+1}/{args.num_train_epochs}\n-------------------------------") 112 | total_loss = train_loop(args, train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss) 113 | metrics = test_loop(args, dev_dataloader, model) 114 | dev_p, dev_r, dev_f1 = metrics['1']['precision'], metrics['1']['recall'], metrics['1']['f1-score'] 115 | logger.info(f'Dev: P - {(100*dev_p):0.4f} R - {(100*dev_r):0.4f} F1 - {(100*dev_f1):0.4f}') 116 | if dev_f1 > best_f1: 117 | best_f1 = dev_f1 118 | logger.info(f'saving new weights to {args.output_dir}...\n') 119 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin' 120 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight)) 121 | with open(os.path.join(args.output_dir, 'dev_metrics.txt'), 'at') as f: 122 | f.write(f'epoch_{epoch+1}\n' + json.dumps(metrics, cls=NpEncoder) + '\n\n') 123 | logger.info("Done!") 124 | 125 | def test(args, test_dataset, model, tokenizer, save_weights:list): 126 | test_dataloader = get_dataLoader(args, test_dataset, tokenizer, batch_size=1, shuffle=False) 127 | logger.info('***** Running testing *****') 128 | for save_weight in save_weights: 129 | logger.info(f'loading weights from {save_weight}...') 130 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight))) 131 | metrics = test_loop(args, test_dataloader, model) 132 | with open(os.path.join(args.output_dir, 'test_metrics.txt'), 'at') as f: 133 | f.write(f'{save_weight}\n{json.dumps(metrics, cls=NpEncoder)}\n\n') 134 | 135 | def predict(args, sent_1, sent_2, e1_char_start, e1_char_end, e2_char_start, e2_char_end, model, tokenizer): 136 | 137 | def _cut_sent(sent, e_char_start, e_char_end, max_length): 138 | before = ' '.join([c for c in sent[:e_char_start].split(' ') if c != ''][-max_length:]).strip() 139 | trigger = sent[e_char_start:e_char_end+1] 140 | after = ' '.join([c for c in sent[e_char_end+1:].split(' ') if c != ''][:max_length]).strip() 141 | new_sent, new_char_start, new_char_end = before + ' ' + trigger + ' ' + after, len(before) + 1, len(before) + len(trigger) 142 | assert new_sent[new_char_start:new_char_end+1] == trigger 143 | return new_sent, new_char_start, new_char_end 144 | 145 | max_mention_length = (args.max_seq_length - 50) // 4 146 | sent_1, e1_char_start, e1_char_end = _cut_sent(sent_1, e1_char_start, e1_char_end, max_mention_length) 147 | sent_2, e2_char_start, e2_char_end = _cut_sent(sent_2, e2_char_start, e2_char_end, max_mention_length) 148 | inputs = tokenizer( 149 | sent_1, 150 | sent_2, 151 | max_length=args.max_seq_length, 152 | truncation=True, 153 | return_tensors="pt" 154 | ) 155 | e1_token_start = inputs.char_to_token(e1_char_start, sequence_index=0) 156 | if not e1_token_start: 157 | e1_token_start = inputs.char_to_token(e1_char_start + 1, sequence_index=0) 158 | e1_token_end = inputs.char_to_token(e1_char_end, sequence_index=0) 159 | e2_token_start = inputs.char_to_token(e2_char_start, sequence_index=1) 160 | if not e2_token_start: 161 | e2_token_start = inputs.char_to_token(e2_char_start + 1, sequence_index=1) 162 | e2_token_end = inputs.char_to_token(e2_char_end, sequence_index=1) 163 | assert e1_token_start and e1_token_end and e2_token_start and e2_token_end 164 | inputs = { 165 | 'batch_inputs': inputs, 166 | 'batch_e1_idx': [[[e1_token_start, e1_token_end]]], 167 | 'batch_e2_idx': [[[e2_token_start, e2_token_end]]] 168 | } 169 | inputs = to_device(args, inputs) 170 | with torch.no_grad(): 171 | outputs 
= model(**inputs) 172 | logits = outputs[1] 173 | pred = int(logits.argmax(dim=-1)[0].cpu().numpy()) 174 | prob = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy().tolist() 175 | return pred, prob[pred] 176 | 177 | def get_event_sent(e_start, e_end, sents): 178 | for sent in sents: 179 | sent_end = sent.start + len(sent.text) - 1 180 | if e_start >= sent.start and e_end <= sent_end: 181 | return sent.text, e_start - sent.start, e_end - sent.start 182 | return None, None, None 183 | 184 | if __name__ == '__main__': 185 | args = parse_args() 186 | if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir): 187 | raise ValueError(f'Output directory ({args.output_dir}) already exists and is not empty.') 188 | if not os.path.exists(args.output_dir): 189 | os.mkdir(args.output_dir) 190 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 191 | args.n_gpu = torch.cuda.device_count() 192 | logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}') 193 | # Set seed 194 | seed_everything(args.seed) 195 | # Load pretrained model and tokenizer 196 | logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...') 197 | config = AutoConfig.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir) 198 | tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir) 199 | args.num_labels = 2 200 | model = MODEL_CLASSES[args.model_type].from_pretrained( 201 | args.model_checkpoint, 202 | config=config, 203 | cache_dir=args.cache_dir, 204 | args=args 205 | ).to(args.device) 206 | # Training 207 | if args.do_train: 208 | train_dataset = KBPCorefPair(args.train_file) 209 | dev_dataset = KBPCorefPair(args.dev_file) 210 | train(args, train_dataset, dev_dataset, model, tokenizer) 211 | # Testing 212 | save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')] 213 | if args.do_test: 214 | test_dataset = KBPCorefPair(args.test_file) 215 | test(args, test_dataset, model, tokenizer, save_weights) 216 | # Predicting 217 | if args.do_predict: 218 | kbp_sent_dic = defaultdict(list) # {filename: [Sentence]} 219 | with open(os.path.join('../../data/kbp_sent.txt'), 'rt', encoding='utf-8') as sents: 220 | for line in sents: 221 | doc_id, start, text = line.strip().split('\t') 222 | kbp_sent_dic[doc_id].append(Sentence(int(start), text)) 223 | 224 | pred_event_file = 'epoch_3_dev_f1_57.9994_weights.bin_test_pred_events.json' 225 | 226 | for best_save_weight in save_weights: 227 | logger.info(f'loading weights from {best_save_weight}...') 228 | model.load_state_dict(torch.load(os.path.join(args.output_dir, best_save_weight))) 229 | logger.info(f'predicting coref labels of {best_save_weight}...') 230 | 231 | results = [] 232 | model.eval() 233 | pred_event_filepath = os.path.join(args.output_dir, pred_event_file) 234 | with open(pred_event_filepath, 'rt' , encoding='utf-8') as f_in: 235 | for line in tqdm(f_in.readlines()): 236 | sample = json.loads(line.strip()) 237 | events = [ 238 | (event['start'], event['start'] + len(event['trigger']) - 1, event['trigger']) 239 | for event in sample['pred_label'] 240 | ] 241 | sents = kbp_sent_dic[sample['doc_id']] 242 | new_events = [] 243 | for e_start, e_end, e_trigger in events: 244 | e_sent, e_new_start, e_new_end = get_event_sent(e_start, e_end, sents) 245 | assert e_sent is not None and e_sent[e_new_start:e_new_end+1] == e_trigger 246 | new_events.append((e_new_start, e_new_end, e_sent)) 247 | predictions, probabilities = [], [] 248 | for i in 
range(len(new_events) - 1): 249 | for j in range(i + 1, len(new_events)): 250 | e1_char_start, e1_char_end, sent_1 = new_events[i] 251 | e2_char_start, e2_char_end, sent_2 = new_events[j] 252 | pred, prob = predict(args, 253 | sent_1, sent_2, 254 | e1_char_start, e1_char_end, 255 | e2_char_start, e2_char_end, 256 | model, tokenizer 257 | ) 258 | predictions.append(pred) 259 | probabilities.append(prob) 260 | results.append({ 261 | "doc_id": sample['doc_id'], 262 | "document": sample['document'], 263 | "events": [ 264 | { 265 | 'start': char_start, 266 | 'end': char_end, 267 | 'trigger': trigger 268 | } for char_start, char_end, trigger in events 269 | ], 270 | "pred_label": predictions, 271 | "pred_prob": probabilities 272 | }) 273 | with open(os.path.join(args.output_dir, best_save_weight + '_test_pred_corefs.json'), 'wt', encoding='utf-8') as f: 274 | for example_result in results: 275 | f.write(json.dumps(example_result) + '\n') 276 | -------------------------------------------------------------------------------- /src/local_event_coref/run_local_base.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./bert_results/ 2 | 3 | python3 run_local_base.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --model_type=bert \ 6 | --model_checkpoint=../../../PT_MODELS/bert-large-cased/ \ 7 | --train_file=../../data/train_filtered.json \ 8 | --dev_file=../../data/dev_filtered.json \ 9 | --test_file=../../data/test_filtered.json \ 10 | --max_seq_length=512 \ 11 | --learning_rate=1e-5 \ 12 | --matching_style=multi \ 13 | --softmax_loss=ce \ 14 | --num_train_epochs=10 \ 15 | --batch_size=4 \ 16 | --do_train \ 17 | --warmup_proportion=0. \ 18 | --seed=42 -------------------------------------------------------------------------------- /src/local_event_coref/run_local_base_with_mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import json 5 | from collections import namedtuple, defaultdict 6 | from tqdm.auto import tqdm 7 | from transformers import AutoConfig, AutoTokenizer 8 | from transformers import AdamW, get_scheduler 9 | from sklearn.metrics import classification_report 10 | import sys 11 | sys.path.append('../../') 12 | from src.tools import seed_everything, NpEncoder 13 | from src.local_event_coref.arg import parse_args 14 | from src.local_event_coref.data import KBPCorefPair, get_dataLoader, SUBTYPES 15 | from src.local_event_coref.modeling import BertForPairwiseECWithMask, RobertaForPairwiseECWithMask 16 | 17 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 18 | datefmt='%Y/%m/%d %H:%M:%S', 19 | level=logging.INFO) 20 | logger = logging.getLogger("Model") 21 | Sentence = namedtuple("Sentence", ["start", "text"]) 22 | 23 | MODEL_CLASSES = { 24 | 'bert': BertForPairwiseECWithMask, 25 | 'spanbert': BertForPairwiseECWithMask, 26 | 'roberta': RobertaForPairwiseECWithMask 27 | } 28 | 29 | def to_device(args, batch_data): 30 | new_batch_data = {} 31 | for k, v in batch_data.items(): 32 | if k in ['batch_inputs', 'batch_inputs_with_mask']: 33 | new_batch_data[k] = { 34 | k_: v_.to(args.device) for k_, v_ in v.items() 35 | } 36 | else: 37 | new_batch_data[k] = torch.tensor(v).to(args.device) 38 | return new_batch_data 39 | 40 | def train_loop(args, dataloader, model, optimizer, lr_scheduler, epoch, total_loss): 41 | progress_bar = tqdm(range(len(dataloader))) 42 | progress_bar.set_description(f'loss: {0:>7f}') 43 | finish_step_num
= epoch * len(dataloader) 44 | 45 | model.train() 46 | for step, batch_data in enumerate(dataloader, start=1): 47 | batch_data = to_device(args, batch_data) 48 | outputs = model(**batch_data) 49 | loss = outputs[0] 50 | 51 | optimizer.zero_grad() 52 | loss.backward() 53 | optimizer.step() 54 | lr_scheduler.step() 55 | 56 | total_loss += loss.item() 57 | progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}') 58 | progress_bar.update(1) 59 | return total_loss 60 | 61 | def test_loop(args, dataloader, model): 62 | true_labels, true_predictions = [], [] 63 | model.eval() 64 | with torch.no_grad(): 65 | for batch_data in tqdm(dataloader): 66 | batch_data = to_device(args, batch_data) 67 | outputs = model(**batch_data) 68 | logits = outputs[1] 69 | 70 | predictions = logits.argmax(dim=-1).cpu().numpy().tolist() 71 | labels = batch_data['labels'].cpu().numpy() 72 | true_predictions += predictions 73 | true_labels += [int(label) for label in labels] 74 | return classification_report(true_labels, true_predictions, output_dict=True) 75 | 76 | def train(args, train_dataset, dev_dataset, model, tokenizer): 77 | """ Train the model """ 78 | train_dataloader = get_dataLoader(args, train_dataset, tokenizer, shuffle=True, collote_fn_type='with_mask') 79 | dev_dataloader = get_dataLoader(args, dev_dataset, tokenizer, shuffle=False, collote_fn_type='with_mask') 80 | t_total = len(train_dataloader) * args.num_train_epochs 81 | # Prepare optimizer and schedule (linear warmup and decay) 82 | no_decay = ["bias", "LayerNorm.weight"] 83 | optimizer_grouped_parameters = [ 84 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay}, 85 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0} 86 | ] 87 | args.warmup_steps = int(t_total * args.warmup_proportion) 88 | optimizer = AdamW( 89 | optimizer_grouped_parameters, 90 | lr=args.learning_rate, 91 | betas=(args.adam_beta1, args.adam_beta2), 92 | eps=args.adam_epsilon 93 | ) 94 | lr_scheduler = get_scheduler( 95 | 'linear', 96 | optimizer, 97 | num_warmup_steps=args.warmup_steps, 98 | num_training_steps=t_total 99 | ) 100 | # Train! 101 | logger.info("***** Running training *****") 102 | logger.info(f"Num examples - {len(train_dataset)}") 103 | logger.info(f"Num Epochs - {args.num_train_epochs}") 104 | logger.info(f"Total optimization steps - {t_total}") 105 | with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f: 106 | f.write(str(args)) 107 | 108 | total_loss = 0. 109 | best_f1 = 0. 
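# Same best-F1-only checkpointing as run_local_base.py. The difference is the
# 'with_mask' collator requested above: each batch carries both batch_inputs
# and a trigger-masked batch_inputs_with_mask copy (moved to the device
# together in to_device) for the *WithMask pairwise models.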
110 | for epoch in range(args.num_train_epochs): 111 | print(f"Epoch {epoch+1}/{args.num_train_epochs}\n-------------------------------") 112 | total_loss = train_loop(args, train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss) 113 | metrics = test_loop(args, dev_dataloader, model) 114 | dev_p, dev_r, dev_f1 = metrics['1']['precision'], metrics['1']['recall'], metrics['1']['f1-score'] 115 | logger.info(f'Dev: P - {(100*dev_p):0.4f} R - {(100*dev_r):0.4f} F1 - {(100*dev_f1):0.4f}') 116 | if dev_f1 > best_f1: 117 | best_f1 = dev_f1 118 | logger.info(f'saving new weights to {args.output_dir}...\n') 119 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin' 120 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight)) 121 | with open(os.path.join(args.output_dir, 'dev_metrics.txt'), 'at') as f: 122 | f.write(f'epoch_{epoch+1}\n' + json.dumps(metrics, cls=NpEncoder) + '\n\n') 123 | logger.info("Done!") 124 | 125 | def test(args, test_dataset, model, tokenizer, save_weights:list): 126 | test_dataloader = get_dataLoader(args, test_dataset, tokenizer, batch_size=1, shuffle=False, collote_fn_type='with_mask') 127 | logger.info('***** Running testing *****') 128 | for save_weight in save_weights: 129 | logger.info(f'loading weights from {save_weight}...') 130 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight))) 131 | metrics = test_loop(args, test_dataloader, model) 132 | with open(os.path.join(args.output_dir, 'test_metrics.txt'), 'at') as f: 133 | f.write(f'{save_weight}\n{json.dumps(metrics, cls=NpEncoder)}\n\n') 134 | 135 | def predict(args, sent_1, sent_2, e1_char_start, e1_char_end, e2_char_start, e2_char_end, model, tokenizer): 136 | 137 | def _cut_sent(sent, e_char_start, e_char_end, max_length): 138 | before = ' '.join([c for c in sent[:e_char_start].split(' ') if c != ''][-max_length:]).strip() 139 | trigger = sent[e_char_start:e_char_end+1] 140 | after = ' '.join([c for c in sent[e_char_end+1:].split(' ') if c != ''][:max_length]).strip() 141 | new_sent, new_char_start, new_char_end = before + ' ' + trigger + ' ' + after, len(before) + 1, len(before) + len(trigger) 142 | assert new_sent[new_char_start:new_char_end+1] == trigger 143 | return new_sent, new_char_start, new_char_end 144 | 145 | max_mention_length = (args.max_seq_length - 50) // 4 146 | sent_1, e1_char_start, e1_char_end = _cut_sent(sent_1, e1_char_start, e1_char_end, max_mention_length) 147 | sent_2, e2_char_start, e2_char_end = _cut_sent(sent_2, e2_char_start, e2_char_end, max_mention_length) 148 | inputs = tokenizer( 149 | sent_1, 150 | sent_2, 151 | max_length=args.max_seq_length, 152 | truncation=True, 153 | return_tensors="pt" 154 | ) 155 | inputs_with_mask = tokenizer( 156 | sent_1, 157 | sent_2, 158 | max_length=args.max_seq_length, 159 | truncation=True, 160 | return_tensors="pt" 161 | ) 162 | e1_token_start = inputs.char_to_token(e1_char_start, sequence_index=0) 163 | if not e1_token_start: 164 | e1_token_start = inputs.char_to_token(e1_char_start + 1, sequence_index=0) 165 | e1_token_end = inputs.char_to_token(e1_char_end, sequence_index=0) 166 | e2_token_start = inputs.char_to_token(e2_char_start, sequence_index=1) 167 | if not e2_token_start: 168 | e2_token_start = inputs.char_to_token(e2_char_start + 1, sequence_index=1) 169 | e2_token_end = inputs.char_to_token(e2_char_end, sequence_index=1) 170 | assert e1_token_start and e1_token_end and e2_token_start and e2_token_end 171 | 
inputs_with_mask['input_ids'][0][e1_token_start:e1_token_end+1] = tokenizer.mask_token_id 172 | inputs_with_mask['input_ids'][0][e2_token_start:e2_token_end+1] = tokenizer.mask_token_id 173 | inputs = { 174 | 'batch_inputs': inputs, 175 | 'batch_inputs_with_mask': inputs_with_mask, 176 | 'batch_e1_idx': [[[e1_token_start, e1_token_end]]], 177 | 'batch_e2_idx': [[[e2_token_start, e2_token_end]]] 178 | } 179 | inputs = to_device(args, inputs) 180 | with torch.no_grad(): 181 | outputs = model(**inputs) 182 | logits = outputs[1] 183 | pred = int(logits.argmax(dim=-1)[0].cpu().numpy()) 184 | prob = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy().tolist() 185 | return pred, prob[pred] 186 | 187 | def get_event_sent(e_start, e_end, sents): 188 | for sent in sents: 189 | sent_end = sent.start + len(sent.text) - 1 190 | if e_start >= sent.start and e_end <= sent_end: 191 | return sent.text, e_start - sent.start, e_end - sent.start 192 | return None, None, None 193 | 194 | if __name__ == '__main__': 195 | args = parse_args() 196 | if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir): 197 | raise ValueError(f'Output directory ({args.output_dir}) already exists and is not empty.') 198 | if not os.path.exists(args.output_dir): 199 | os.mkdir(args.output_dir) 200 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 201 | args.n_gpu = torch.cuda.device_count() 202 | logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}') 203 | # Set seed 204 | seed_everything(args.seed) 205 | # Load pretrained model and tokenizer 206 | logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...') 207 | config = AutoConfig.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir) 208 | tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir) 209 | args.num_labels = 2 210 | args.num_subtypes = len(SUBTYPES) + 1 211 | model = MODEL_CLASSES[args.model_type].from_pretrained( 212 | args.model_checkpoint, 213 | config=config, 214 | cache_dir=args.cache_dir, 215 | args=args 216 | ).to(args.device) 217 | # Training 218 | if args.do_train: 219 | train_dataset = KBPCorefPair(args.train_file) 220 | dev_dataset = KBPCorefPair(args.dev_file) 221 | train(args, train_dataset, dev_dataset, model, tokenizer) 222 | # Testing 223 | save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')] 224 | if args.do_test: 225 | test_dataset = KBPCorefPair(args.test_file) 226 | test(args, test_dataset, model, tokenizer, save_weights) 227 | # Predicting 228 | if args.do_predict: 229 | kbp_sent_dic = defaultdict(list) # {filename: [Sentence]} 230 | with open(os.path.join('../../data/kbp_sent.txt'), 'rt', encoding='utf-8') as sents: 231 | for line in sents: 232 | doc_id, start, text = line.strip().split('\t') 233 | kbp_sent_dic[doc_id].append(Sentence(int(start), text)) 234 | 235 | pred_event_file = 'epoch_3_dev_f1_57.9994_weights.bin_test_pred_events.json' 236 | 237 | for best_save_weight in save_weights: 238 | logger.info(f'loading weights from {best_save_weight}...') 239 | model.load_state_dict(torch.load(os.path.join(args.output_dir, best_save_weight))) 240 | logger.info(f'predicting coref labels of {best_save_weight}...') 241 | 242 | results = [] 243 | model.eval() 244 | pred_event_filepath = os.path.join(args.output_dir, pred_event_file) 245 | with open(pred_event_filepath, 'rt' , encoding='utf-8') as f_in: 246 | for line in tqdm(f_in.readlines()): 247 | sample = json.loads(line.strip()) 248 | 
events = [ 249 | (event['start'], event['start'] + len(event['trigger']) - 1, event['trigger']) 250 | for event in sample['pred_label'] 251 | ] 252 | sents = kbp_sent_dic[sample['doc_id']] 253 | new_events = [] 254 | for e_start, e_end, e_trigger in events: 255 | e_sent, e_new_start, e_new_end = get_event_sent(e_start, e_end, sents) 256 | assert e_sent is not None and e_sent[e_new_start:e_new_end+1] == e_trigger 257 | new_events.append((e_new_start, e_new_end, e_sent)) 258 | predictions, probabilities = [], [] 259 | for i in range(len(new_events) - 1): # score every mention pair (i < j) with the pairwise coreference model 260 | for j in range(i + 1, len(new_events)): 261 | e1_char_start, e1_char_end, sent_1 = new_events[i] 262 | e2_char_start, e2_char_end, sent_2 = new_events[j] 263 | pred, prob = predict(args, 264 | sent_1, sent_2, 265 | e1_char_start, e1_char_end, 266 | e2_char_start, e2_char_end, 267 | model, tokenizer 268 | ) 269 | predictions.append(pred) 270 | probabilities.append(prob) 271 | results.append({ 272 | "doc_id": sample['doc_id'], 273 | "document": sample['document'], 274 | "events": [ 275 | { 276 | 'start': char_start, 277 | 'end': char_end, 278 | 'trigger': trigger 279 | } for char_start, char_end, trigger in events 280 | ], 281 | "pred_label": predictions, 282 | "pred_prob": probabilities 283 | }) 284 | with open(os.path.join(args.output_dir, best_save_weight + '_test_pred_corefs.json'), 'wt', encoding='utf-8') as f: 285 | for example_result in results: 286 | f.write(json.dumps(example_result) + '\n') 287 | -------------------------------------------------------------------------------- /src/local_event_coref/run_local_base_with_mask.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./mask_bert_results/ 2 | 3 | python3 run_local_base_with_mask.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --model_type=bert \ 6 | --model_checkpoint=../../../PT_MODELS/bert-large-cased/ \ 7 | --train_file=../../data/train_filtered.json \ 8 | --dev_file=../../data/dev_filtered.json \ 9 | --test_file=../../data/test_filtered.json \ 10 | --max_seq_length=512 \ 11 | --learning_rate=1e-5 \ 12 | --matching_style=multi \ 13 | --softmax_loss=ce \ 14 | --num_train_epochs=10 \ 15 | --batch_size=4 \ 16 | --do_train \ 17 | --warmup_proportion=0. \ 18 | --seed=42 -------------------------------------------------------------------------------- /src/local_event_coref/run_local_base_with_mask_topic.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./mask_topic_bert_results/ 2 | 3 | python3 run_local_base_with_mask_topic.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --model_type=bert \ 6 | --model_checkpoint=../../../PT_MODELS/bert-large-cased/ \ 7 | --topic_model=vmf \ 8 | --topic_dim=32 \ 9 | --topic_inter_map=64 \ 10 | --train_file=../../data/train_filtered.json \ 11 | --dev_file=../../data/dev_filtered.json \ 12 | --test_file=../../data/test_filtered.json \ 13 | --max_seq_length=512 \ 14 | --learning_rate=1e-5 \ 15 | --matching_style=multi \ 16 | --softmax_loss=ce \ 17 | --num_train_epochs=10 \ 18 | --batch_size=4 \ 19 | --do_train \ 20 | --warmup_proportion=0. 
\ 21 | --seed=42 -------------------------------------------------------------------------------- /src/local_event_coref/run_local_base_with_topic.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./topic_bert_results/ 2 | 3 | python3 run_local_base_with_topic.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --model_type=bert \ 6 | --model_checkpoint=../../../PT_MODELS/bert-large-cased/ \ 7 | --topic_model=vmf \ 8 | --topic_dim=32 \ 9 | --topic_inter_map=64 \ 10 | --train_file=../../data/train_filtered.json \ 11 | --dev_file=../../data/dev_filtered.json \ 12 | --test_file=../../data/test_filtered.json \ 13 | --max_seq_length=512 \ 14 | --learning_rate=1e-5 \ 15 | --matching_style=multi \ 16 | --softmax_loss=ce \ 17 | --num_train_epochs=10 \ 18 | --batch_size=4 \ 19 | --do_train \ 20 | --warmup_proportion=0. \ 21 | --seed=42 -------------------------------------------------------------------------------- /src/trigger_detection/arg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # Required parameters 7 | parser.add_argument("--output_dir", default=None, type=str, required=True, 8 | help="The output directory where the model checkpoints and predictions will be written.", 9 | ) 10 | parser.add_argument("--train_file", default=None, type=str, required=True, help="The input training file.") 11 | parser.add_argument("--dev_file", default=None, type=str, required=True, help="The input evaluation file.") 12 | parser.add_argument("--test_file", default=None, type=str, required=True, help="The input testing file.") 13 | 14 | parser.add_argument("--model_type", 15 | default="longformer", type=str, required=True 16 | ) 17 | parser.add_argument("--model_checkpoint", 18 | default="allenai/longformer-base-4096", type=str, required=True, 19 | help="Path to pretrained model or model identifier from huggingface.co/models", 20 | ) 21 | parser.add_argument("--max_seq_length", default=4096, type=int, required=True) 22 | 23 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 24 | parser.add_argument("--do_test", action="store_true", help="Whether to run eval on the test set.") 25 | parser.add_argument("--do_predict", action="store_true", help="Whether to save predicted labels.") 26 | 27 | # Other parameters 28 | parser.add_argument("--cache_dir", default=None, type=str, 29 | help="Where to store the pre-trained models downloaded from s3." 30 | ) 31 | parser.add_argument("--use_ffnn_layer", action="store_true", help="Whether to add an FFNN before the classifier.") 32 | parser.add_argument("--ffnn_size", default=-1, type=int, help="The size of the FFNN layer.") 33 | parser.add_argument("--softmax_loss", default='ce', type=str, 34 | help="The loss function for the softmax model.", 35 | choices=['lsr', 'focal', 'ce'] 36 | ) 37 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") 38 | parser.add_argument("--crf_learning_rate", default=5e-5, type=float, help="The initial learning rate for the CRF.") 39 | parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.") 40 | parser.add_argument("--batch_size", default=4, type=int) 41 | parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization.") 42 | 43 | parser.add_argument("--adam_beta1", default=0.9, type=float, 44 | help="Beta1 
for Adam optimizer." 45 | ) 46 | parser.add_argument("--adam_beta2", default=0.98, type=float, 47 | help="Beta2 for Adam optimizer." 48 | ) 49 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 50 | help="Epsilon for Adam optimizer." 51 | ) 52 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 53 | help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% of training." 54 | ) 55 | parser.add_argument("--weight_decay", default=0.01, type=float, 56 | help="Weight decay to apply, if any." 57 | ) 58 | args = parser.parse_args() 59 | return args -------------------------------------------------------------------------------- /src/trigger_detection/data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import json 3 | import numpy as np 4 | import torch 5 | 6 | CATEGORIES = [ 7 | 'artifact', 'transferownership', 'transaction', 'broadcast', 'contact', 'demonstrate', \ 8 | 'injure', 'transfermoney', 'transportartifact', 'attack', 'meet', 'elect', \ 9 | 'endposition', 'correspondence', 'arrestjail', 'startposition', 'transportperson', 'die' 10 | ] 11 | 12 | class KBPTrigger(Dataset): 13 | def __init__(self, data_file): 14 | self.data = self.load_data(data_file) 15 | 16 | def load_data(self, data_file): 17 | Data = [] 18 | with open(data_file, 'rt', encoding='utf-8') as f: 19 | for line in f: 20 | sample = json.loads(line.strip()) 21 | tags = [ 22 | (event['start'], event['start'] + len(event['trigger']) - 1, event['trigger'], event['subtype']) 23 | for event in sample['events'] if event['subtype'] in CATEGORIES 24 | ] 25 | Data.append({ 26 | 'id': sample['doc_id'], 27 | 'document': sample['document'], 28 | 'tags': tags 29 | }) 30 | return Data 31 | 32 | def __len__(self): 33 | return len(self.data) 34 | 35 | def __getitem__(self, idx): 36 | return self.data[idx] 37 | 38 | def get_dataLoader(args, dataset, tokenizer, batch_size=None, shuffle=False): 39 | 40 | def collote_fn(batch_samples): 41 | batch_sentence, batch_tags = [], [] 42 | for sample in batch_samples: 43 | batch_sentence.append(sample['document']) 44 | batch_tags.append(sample['tags']) 45 | batch_inputs = tokenizer( 46 | batch_sentence, 47 | max_length=args.max_seq_length, 48 | padding=True, 49 | truncation=True, 50 | return_tensors="pt" 51 | ) 52 | batch_label = np.zeros(batch_inputs['input_ids'].shape, dtype=int) # 0 = 'O'; B-/I- ids for each gold trigger span are filled in below 53 | for s_idx, sentence in enumerate(batch_sentence): 54 | encoding = tokenizer(sentence, max_length=args.max_seq_length, truncation=True) 55 | for char_start, char_end, _, tag in batch_tags[s_idx]: 56 | token_start = encoding.char_to_token(char_start) 57 | token_end = encoding.char_to_token(char_end) 58 | if not token_start or not token_end: 59 | continue 60 | batch_label[s_idx][token_start] = args.label2id[f"B-{tag}"] 61 | batch_label[s_idx][token_start + 1:token_end + 1] = args.label2id[f"I-{tag}"] 62 | batch_inputs['labels'] = torch.tensor(batch_label) 63 | return batch_inputs 64 | 65 | return DataLoader(dataset, batch_size=(batch_size if batch_size else args.batch_size), shuffle=shuffle, 66 | collate_fn=collote_fn) -------------------------------------------------------------------------------- /src/trigger_detection/modeling.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import CrossEntropyLoss 3 | from transformers import LongformerPreTrainedModel, LongformerModel 4 | from ..tools 
import LabelSmoothingCrossEntropy, FocalLoss, CRF 5 | from ..tools import FullyConnectedLayer 6 | 7 | class LongformerSoftmaxForTD(LongformerPreTrainedModel): 8 | def __init__(self, config, args): 9 | super().__init__(config) 10 | self.num_labels = args.num_labels 11 | self.longformer = LongformerModel(config, add_pooling_layer=False) 12 | self.use_ffnn_layer = args.use_ffnn_layer 13 | if self.use_ffnn_layer: 14 | self.ffnn_size = args.ffnn_size if args.ffnn_size != -1 else config.hidden_size 15 | self.mlp = FullyConnectedLayer(config, config.hidden_size, self.ffnn_size, config.hidden_dropout_prob) 16 | else: 17 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 18 | self.classifier = nn.Linear(self.ffnn_size if args.use_ffnn_layer else config.hidden_size, self.num_labels) 19 | self.loss_type = args.softmax_loss 20 | self.post_init() 21 | 22 | def forward(self, input_ids, attention_mask, labels=None): 23 | outputs = self.longformer(input_ids, attention_mask=attention_mask) 24 | sequence_output = outputs[0] 25 | if self.use_ffnn_layer: 26 | sequence_output = self.mlp(sequence_output) 27 | else: 28 | sequence_output = self.dropout(sequence_output) 29 | logits = self.classifier(sequence_output) 30 | 31 | loss = None 32 | if labels is not None: 33 | assert self.loss_type in ['lsr', 'focal', 'ce'] 34 | if self.loss_type == 'lsr': 35 | loss_fct = LabelSmoothingCrossEntropy() 36 | elif self.loss_type == 'focal': 37 | loss_fct = FocalLoss() 38 | else: 39 | loss_fct = CrossEntropyLoss() 40 | # Only keep active parts of the loss 41 | if attention_mask is not None: 42 | active_loss = attention_mask.view(-1) == 1 43 | active_logits = logits.view(-1, self.num_labels)[active_loss] 44 | active_labels = labels.view(-1)[active_loss] 45 | loss = loss_fct(active_logits, active_labels) 46 | else: 47 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 48 | return loss, logits 49 | 50 | class LongformerCrfForTD(LongformerPreTrainedModel): 51 | def __init__(self, config, args): 52 | super().__init__(config) 53 | self.num_labels = args.num_labels 54 | self.longformer = LongformerModel(config, add_pooling_layer=False) 55 | self.use_ffnn_layer = args.use_ffnn_layer 56 | if self.use_ffnn_layer: 57 | self.ffnn_size = args.ffnn_size if args.ffnn_size != -1 else config.hidden_size 58 | self.mlp = FullyConnectedLayer(config, config.hidden_size, self.ffnn_size, config.hidden_dropout_prob) 59 | else: 60 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 61 | self.classifier = nn.Linear(self.ffnn_size if args.use_ffnn_layer else config.hidden_size, self.num_labels) 62 | self.crf = CRF(num_tags=self.num_labels, batch_first=True) 63 | self.post_init() 64 | 65 | def forward(self, input_ids, attention_mask, labels=None): 66 | outputs = self.longformer(input_ids, attention_mask=attention_mask) 67 | sequence_output = outputs[0] 68 | if self.use_ffnn_layer: 69 | sequence_output = self.mlp(sequence_output) 70 | else: 71 | sequence_output = self.dropout(sequence_output) 72 | logits = self.classifier(sequence_output) 73 | 74 | loss = None 75 | if labels is not None: 76 | loss = -1 * self.crf(emissions=logits, tags=labels, mask=attention_mask) 77 | return loss, logits 78 | -------------------------------------------------------------------------------- /src/trigger_detection/run_td_crf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | from tqdm.auto import tqdm 5 | import numpy as np 6 | import torch 7 | from 
transformers import AutoConfig, AutoTokenizer 8 | from transformers import AdamW, get_scheduler 9 | from seqeval.metrics import classification_report 10 | from seqeval.scheme import IOB2 11 | import sys 12 | sys.path.append('../../') 13 | from src.trigger_detection.data import KBPTrigger, get_dataLoader, CATEGORIES 14 | from src.trigger_detection.modeling import LongformerCrfForTD 15 | from src.trigger_detection.arg import parse_args 16 | from src.tools import seed_everything, NpEncoder 17 | 18 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 19 | datefmt='%Y/%m/%d %H:%M:%S', 20 | level=logging.INFO) 21 | logger = logging.getLogger("Model") 22 | 23 | def train_loop(args, dataloader, model, optimizer, lr_scheduler, epoch, total_loss): 24 | progress_bar = tqdm(range(len(dataloader))) 25 | progress_bar.set_description(f'loss: {0:>7f}') 26 | finish_step_num = epoch * len(dataloader) 27 | 28 | model.train() 29 | for step, batch_data in enumerate(dataloader, start=1): 30 | batch_data = batch_data.to(args.device) 31 | outputs = model(**batch_data) 32 | loss = outputs[0] 33 | 34 | optimizer.zero_grad() 35 | loss.backward() 36 | optimizer.step() 37 | lr_scheduler.step() 38 | 39 | total_loss += loss.item() 40 | progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}') 41 | progress_bar.update(1) 42 | return total_loss 43 | 44 | def test_loop(args, dataloader, model): 45 | true_labels, true_predictions = [], [] 46 | model.eval() 47 | with torch.no_grad(): 48 | for batch_data in tqdm(dataloader): 49 | batch_data = batch_data.to(args.device) 50 | outputs = model(**batch_data) 51 | logits = outputs[1] 52 | tags = model.crf.decode(logits, batch_data['attention_mask']) 53 | predictions = tags.squeeze(0).cpu().numpy() 54 | labels = batch_data['labels'].cpu().numpy() 55 | lens = np.sum(batch_data['attention_mask'].cpu().numpy(), axis=-1) 56 | true_labels += [ 57 | [args.id2label[int(l)] for idx, l in enumerate(label) if idx > 0 and idx < seq_len - 1] 58 | for label, seq_len in zip(labels, lens) 59 | ] 60 | true_predictions += [ 61 | [args.id2label[int(p)] for idx, p in enumerate(prediction) if idx > 0 and idx < seq_len - 1] 62 | for prediction, seq_len in zip(predictions, lens) 63 | ] 64 | return classification_report(true_labels, true_predictions, mode='strict', scheme=IOB2, output_dict=True) 65 | 66 | def train(args, train_dataset, dev_dataset, model, tokenizer): 67 | """ Train the model """ 68 | train_dataloader = get_dataLoader(args, train_dataset, tokenizer, shuffle=True) 69 | dev_dataloader = get_dataLoader(args, dev_dataset, tokenizer, shuffle=False) 70 | t_total = len(train_dataloader) * args.num_train_epochs 71 | # Prepare optimizer and schedule (linear warmup and decay) 72 | no_decay = ["bias", "LayerNorm.weight"] 73 | longformer_param_optimizer = list(model.longformer.named_parameters()) 74 | crf_param_optimizer = list(model.crf.named_parameters()) 75 | linear_param_optimizer = list(model.classifier.named_parameters()) 76 | optimizer_grouped_parameters = [ 77 | {'params': [p for n, p in longformer_param_optimizer if not any(nd in n for nd in no_decay)], 78 | 'weight_decay': args.weight_decay, 'lr': args.learning_rate}, 79 | {'params': [p for n, p in longformer_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 80 | 'lr': args.learning_rate}, 81 | 82 | {'params': [p for n, p in crf_param_optimizer if not any(nd in n for nd in no_decay)], 83 | 'weight_decay': args.weight_decay, 'lr': args.crf_learning_rate}, 84 | 
{'params': [p for n, p in crf_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 85 | 'lr': args.crf_learning_rate}, 86 | 87 | {'params': [p for n, p in linear_param_optimizer if not any(nd in n for nd in no_decay)], 88 | 'weight_decay': args.weight_decay, 'lr': args.crf_learning_rate}, 89 | {'params': [p for n, p in linear_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 90 | 'lr': args.crf_learning_rate} 91 | ] 92 | args.warmup_steps = int(t_total * args.warmup_proportion) 93 | optimizer = AdamW( 94 | optimizer_grouped_parameters, 95 | lr=args.learning_rate, 96 | betas=(args.adam_beta1, args.adam_beta2), 97 | eps=args.adam_epsilon 98 | ) 99 | lr_scheduler = get_scheduler( 100 | 'linear', 101 | optimizer, 102 | num_warmup_steps=args.warmup_steps, 103 | num_training_steps=t_total 104 | ) 105 | # Train! 106 | logger.info("***** Running training *****") 107 | logger.info(f"Num examples - {len(train_dataset)}") 108 | logger.info(f"Num Epochs - {args.num_train_epochs}") 109 | logger.info(f"Total optimization steps - {t_total}") 110 | with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f: 111 | f.write(str(args)) 112 | 113 | total_loss = 0. 114 | best_f1 = 0. 115 | for epoch in range(args.num_train_epochs): 116 | print(f"Epoch {epoch+1}/{args.num_train_epochs}\n-------------------------------") 117 | total_loss = train_loop(args, train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss) 118 | metrics = test_loop(args, dev_dataloader, model) 119 | micro_f1, macro_f1 = metrics['micro avg']['f1-score'], metrics['macro avg']['f1-score'] 120 | dev_f1 = metrics['weighted avg']['f1-score'] 121 | logger.info(f'Dev: micro_F1 - {(100*micro_f1):0.4f} macro_f1 - {(100*macro_f1):0.4f} weighted_f1 - {(100*dev_f1):0.4f}') 122 | if dev_f1 > best_f1: 123 | best_f1 = dev_f1 124 | logger.info(f'saving new weights to {args.output_dir}...\n') 125 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin' 126 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight)) 127 | with open(os.path.join(args.output_dir, 'dev_metrics.txt'), 'at') as f: 128 | f.write(f'epoch_{epoch+1}\n' + json.dumps(metrics, cls=NpEncoder) + '\n\n') 129 | logger.info("Done!") 130 | 131 | def predict(args, document:str, model, tokenizer): 132 | inputs = tokenizer( 133 | document, 134 | max_length=args.max_seq_length, 135 | truncation=True, 136 | return_tensors="pt", 137 | return_offsets_mapping=True 138 | ) 139 | offsets = inputs.pop('offset_mapping').squeeze(0) 140 | inputs = inputs.to(args.device) 141 | with torch.no_grad(): 142 | outputs = model(**inputs) 143 | logits = outputs[1] 144 | probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy().tolist() 145 | predictions = model.crf.decode(logits, inputs['attention_mask']) 146 | predictions = predictions.squeeze(0)[0].cpu().numpy().tolist() 147 | 148 | pred_label = [] 149 | idx = 1 150 | while idx < len(predictions) - 1: 151 | pred = predictions[idx] 152 | label = args.id2label[pred] 153 | if label != "O": 154 | label = label[2:] # Remove the B- or I- 155 | start, end = offsets[idx] 156 | all_scores = [probabilities[idx][pred]] 157 | # Grab all the tokens labeled with I-label 158 | while ( 159 | idx + 1 < len(predictions) - 1 and 160 | args.id2label[predictions[idx + 1]] == f"I-{label}" 161 | ): 162 | all_scores.append(probabilities[idx + 1][predictions[idx + 1]]) 163 | _, end = offsets[idx + 1] 164 | idx += 1 165 | 166 | score = np.mean(all_scores).item() 167 | start, 
end = start.item(), end.item() 168 | word = document[start:end] 169 | pred_label.append({ 170 | "trigger": word, 171 | "start": start, 172 | "subtype": label, 173 | "score": score 174 | }) 175 | idx += 1 176 | return pred_label 177 | 178 | def test(args, test_dataset, model, tokenizer, save_weights:list): 179 | test_dataloader = get_dataLoader(args, test_dataset, tokenizer, batch_size=1, shuffle=False) 180 | logger.info('***** Running testing *****') 181 | for save_weight in save_weights: 182 | logger.info(f'loading weights from {save_weight}...') 183 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight))) 184 | metrics = test_loop(args, test_dataloader, model) 185 | with open(os.path.join(args.output_dir, 'test_metrics.txt'), 'at') as f: 186 | f.write(f'{save_weight}\n{json.dumps(metrics, cls=NpEncoder)}\n\n') 187 | if args.do_predict: 188 | logger.info(f'predicting labels of {save_weight}...') 189 | results = [] 190 | model.eval() 191 | for sample in tqdm(test_dataset): 192 | pred_label = predict(args, sample['document'], model, tokenizer) 193 | results.append({ 194 | "doc_id": sample['id'], 195 | "document": sample['document'], 196 | "pred_label": pred_label, 197 | "true_label": sample['tags'] 198 | }) 199 | with open(os.path.join(args.output_dir, save_weight + '_test_pred_events.json'), 'wt', encoding='utf-8') as f: 200 | for example_result in results: 201 | f.write(json.dumps(example_result) + '\n') 202 | 203 | if __name__ == '__main__': 204 | args = parse_args() 205 | if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir): 206 | raise ValueError( 207 | f'Output directory ({args.output_dir}) already exists and is not empty.') 208 | if not os.path.exists(args.output_dir): 209 | os.mkdir(args.output_dir) 210 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 211 | args.n_gpu = torch.cuda.device_count() 212 | logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}') 213 | # Set seed 214 | seed_everything(args.seed) 215 | # Prepare task 216 | args.id2label = {0:'O'} 217 | for c in CATEGORIES: 218 | args.id2label[len(args.id2label)] = f"B-{c}" 219 | args.id2label[len(args.id2label)] = f"I-{c}" 220 | args.label2id = {v: k for k, v in args.id2label.items()} 221 | args.num_labels = len(args.id2label) 222 | # Load pretrained model and tokenizer 223 | logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...') 224 | config = AutoConfig.from_pretrained( 225 | args.model_checkpoint, 226 | cache_dir=args.cache_dir 227 | ) 228 | tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir) 229 | model = LongformerCrfForTD.from_pretrained( 230 | args.model_checkpoint, 231 | config=config, 232 | cache_dir=args.cache_dir, 233 | args=args 234 | ).to(args.device) 235 | # Training 236 | if args.do_train: 237 | logger.info(f'Training/evaluation parameters: {args}') 238 | train_dataset = KBPTrigger(args.train_file) 239 | dev_dataset = KBPTrigger(args.dev_file) 240 | train(args, train_dataset, dev_dataset, model, tokenizer) 241 | # Testing 242 | save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')] 243 | test_dataset = KBPTrigger(args.test_file) # loaded unconditionally so --do_predict also works without --do_test 244 | if args.do_test: 245 | test(args, test_dataset, model, tokenizer, save_weights) 246 | # Predicting 247 | if args.do_predict: 248 | for save_weight in save_weights: 249 | logger.info(f'loading weights from {save_weight}...') 250 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight))) 
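    | # With this checkpoint loaded, the block below re-runs prediction over the test set and writes each document's predicted trigger spans next to its gold tags.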
251 | logger.info(f'predicting labels of {save_weight}...') 252 | 253 | results = [] 254 | model.eval() 255 | for sample in tqdm(test_dataset): 256 | pred_label = predict(args, sample['document'], model, tokenizer) 257 | results.append({ 258 | "doc_id": sample['id'], 259 | "document": sample['document'], 260 | "pred_label": pred_label, 261 | "true_label": sample['tags'] 262 | }) 263 | with open(os.path.join(args.output_dir, save_weight + '_test_pred_events.json'), 'wt', encoding='utf-8') as f: 264 | for example_result in results: 265 | f.write(json.dumps(example_result) + '\n') -------------------------------------------------------------------------------- /src/trigger_detection/run_td_crf.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./crf_results/ 2 | 3 | python3 run_td_crf.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --model_type=longformer \ 6 | --model_checkpoint=../../PT_MODELS/allenai/longformer-large-4096/ \ 7 | --train_file=../../data/train_filtered.json \ 8 | --dev_file=../../data/dev_filtered.json \ 9 | --test_file=../../data/test_filtered.json \ 10 | --max_seq_length=4096 \ 11 | --learning_rate=1e-5 \ 12 | --crf_learning_rate=5e-5 \ 13 | --num_train_epochs=50 \ 14 | --batch_size=1 \ 15 | --do_train \ 16 | --warmup_proportion=0. \ 17 | --seed=42 -------------------------------------------------------------------------------- /src/trigger_detection/run_td_softmax.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | from tqdm.auto import tqdm 5 | import numpy as np 6 | import torch 7 | from transformers import AutoConfig, AutoTokenizer 8 | from transformers import AdamW, get_scheduler 9 | from seqeval.metrics import classification_report 10 | from seqeval.scheme import IOB2 11 | import sys 12 | sys.path.append('../../') 13 | from src.trigger_detection.data import KBPTrigger, get_dataLoader, CATEGORIES 14 | from src.trigger_detection.modeling import LongformerSoftmaxForTD 15 | from src.trigger_detection.arg import parse_args 16 | from src.tools import seed_everything, NpEncoder 17 | 18 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 19 | datefmt='%Y/%m/%d %H:%M:%S', 20 | level=logging.INFO) 21 | logger = logging.getLogger("Model") 22 | 23 | def train_loop(args, dataloader, model, optimizer, lr_scheduler, epoch, total_loss): 24 | progress_bar = tqdm(range(len(dataloader))) 25 | progress_bar.set_description(f'loss: {0:>7f}') 26 | finish_step_num = epoch * len(dataloader) 27 | 28 | model.train() 29 | for step, batch_data in enumerate(dataloader, start=1): 30 | batch_data = batch_data.to(args.device) 31 | outputs = model(**batch_data) 32 | loss = outputs[0] 33 | 34 | optimizer.zero_grad() 35 | loss.backward() 36 | optimizer.step() 37 | lr_scheduler.step() 38 | 39 | total_loss += loss.item() 40 | progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}') 41 | progress_bar.update(1) 42 | return total_loss 43 | 44 | def test_loop(args, dataloader, model): 45 | true_labels, true_predictions = [], [] 46 | model.eval() 47 | with torch.no_grad(): 48 | for batch_data in tqdm(dataloader): 49 | batch_data = batch_data.to(args.device) 50 | outputs = model(**batch_data) 51 | logits = outputs[1] 52 | predictions = logits.argmax(dim=-1).cpu().numpy() # [batch, seq] 53 | labels = batch_data['labels'].cpu().numpy() 54 | lens = np.sum(batch_data['attention_mask'].cpu().numpy(), axis=-1) 55 | true_labels 
+= [ 56 | [args.id2label[int(l)] for idx, l in enumerate(label) if idx > 0 and idx < seq_len - 1] 57 | for label, seq_len in zip(labels, lens) 58 | ] 59 | true_predictions += [ 60 | [args.id2label[int(p)] for idx, p in enumerate(prediction) if idx > 0 and idx < seq_len - 1] 61 | for prediction, seq_len in zip(predictions, lens) 62 | ] 63 | return classification_report(true_labels, true_predictions, mode='strict', scheme=IOB2, output_dict=True) 64 | 65 | def train(args, train_dataset, dev_dataset, model, tokenizer): 66 | """ Train the model """ 67 | train_dataloader = get_dataLoader(args, train_dataset, tokenizer, shuffle=True) 68 | dev_dataloader = get_dataLoader(args, dev_dataset, tokenizer, shuffle=False) 69 | t_total = len(train_dataloader) * args.num_train_epochs 70 | # Prepare optimizer and schedule (linear warmup and decay) 71 | no_decay = ["bias", "LayerNorm.weight"] 72 | optimizer_grouped_parameters = [ 73 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay}, 74 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0} 75 | ] 76 | args.warmup_steps = int(t_total * args.warmup_proportion) 77 | optimizer = AdamW( 78 | optimizer_grouped_parameters, 79 | lr=args.learning_rate, 80 | betas=(args.adam_beta1, args.adam_beta2), 81 | eps=args.adam_epsilon 82 | ) 83 | lr_scheduler = get_scheduler( 84 | 'linear', 85 | optimizer, 86 | num_warmup_steps=args.warmup_steps, 87 | num_training_steps=t_total 88 | ) 89 | # Train! 90 | logger.info("***** Running training *****") 91 | logger.info(f"Num examples - {len(train_dataset)}") 92 | logger.info(f"Num Epochs - {args.num_train_epochs}") 93 | logger.info(f"Total optimization steps - {t_total}") 94 | with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f: 95 | f.write(str(args)) 96 | 97 | total_loss = 0. 98 | best_f1 = 0. 
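    | # Epoch loop below: after each pass over the training set, evaluate on dev and checkpoint the weights whenever the weighted F1 improves.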
99 | for epoch in range(args.num_train_epochs): 100 | print(f"Epoch {epoch+1}/{args.num_train_epochs}\n-------------------------------") 101 | total_loss = train_loop(args, train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss) 102 | metrics = test_loop(args, dev_dataloader, model) 103 | micro_f1, macro_f1 = metrics['micro avg']['f1-score'], metrics['macro avg']['f1-score'] 104 | dev_f1 = metrics['weighted avg']['f1-score'] 105 | logger.info(f'Dev: micro_F1 - {(100*micro_f1):0.4f} macro_f1 - {(100*macro_f1):0.4f} weighted_f1 - {(100*dev_f1):0.4f}') 106 | if dev_f1 > best_f1: 107 | best_f1 = dev_f1 108 | logger.info(f'saving new weights to {args.output_dir}...\n') 109 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin' 110 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight)) 111 | with open(os.path.join(args.output_dir, 'dev_metrics.txt'), 'at') as f: 112 | f.write(f'epoch_{epoch+1}\n' + json.dumps(metrics, cls=NpEncoder) + '\n\n') 113 | logger.info("Done!") 114 | 115 | def predict(args, document:str, model, tokenizer): 116 | inputs = tokenizer( 117 | document, 118 | max_length=args.max_seq_length, 119 | truncation=True, 120 | return_tensors="pt", 121 | return_offsets_mapping=True 122 | ) 123 | offsets = inputs.pop('offset_mapping').squeeze(0) 124 | inputs = inputs.to(args.device) 125 | with torch.no_grad(): 126 | outputs = model(**inputs) 127 | logits = outputs[1] 128 | probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy().tolist() 129 | predictions = logits.argmax(dim=-1)[0].cpu().numpy().tolist() 130 | 131 | pred_label = [] 132 | idx = 1 133 | while idx < len(predictions) - 1: 134 | pred = predictions[idx] 135 | label = args.id2label[pred] 136 | if label != "O": 137 | label = label[2:] # Remove the B- or I- 138 | start, end = offsets[idx] 139 | all_scores = [probabilities[idx][pred]] 140 | # Grab all the tokens labeled with I-label 141 | while ( 142 | idx + 1 < len(predictions) - 1 and 143 | args.id2label[predictions[idx + 1]] == f"I-{label}" 144 | ): 145 | all_scores.append(probabilities[idx + 1][predictions[idx + 1]]) 146 | _, end = offsets[idx + 1] 147 | idx += 1 148 | 149 | score = np.mean(all_scores).item() 150 | start, end = start.item(), end.item() 151 | word = document[start:end] 152 | pred_label.append({ 153 | "trigger": word, 154 | "start": start, 155 | "subtype": label, 156 | "score": score 157 | }) 158 | idx += 1 159 | return pred_label 160 | 161 | def test(args, test_dataset, model, tokenizer, save_weights:list): 162 | test_dataloader = get_dataLoader(args, test_dataset, tokenizer, batch_size=1, shuffle=False) 163 | logger.info('***** Running testing *****') 164 | for save_weight in save_weights: 165 | logger.info(f'loading weights from {save_weight}...') 166 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight))) 167 | metrics = test_loop(args, test_dataloader, model) 168 | with open(os.path.join(args.output_dir, 'test_metrics.txt'), 'at') as f: 169 | f.write(f'{save_weight}\n{json.dumps(metrics, cls=NpEncoder)}\n\n') 170 | 171 | if __name__ == '__main__': 172 | args = parse_args() 173 | if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir): 174 | raise ValueError( 175 | f'Output directory ({args.output_dir}) already exists and is not empty.') 176 | if not os.path.exists(args.output_dir): 177 | os.mkdir(args.output_dir) 178 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 179 | args.n_gpu = torch.cuda.device_count() 180 | 
logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}') 181 | # Set seed 182 | seed_everything(args.seed) 183 | # Prepare task 184 | args.id2label = {0:'O'} 185 | for c in CATEGORIES: 186 | args.id2label[len(args.id2label)] = f"B-{c}" 187 | args.id2label[len(args.id2label)] = f"I-{c}" 188 | args.label2id = {v: k for k, v in args.id2label.items()} 189 | args.num_labels = len(args.id2label) 190 | # Load pretrained model and tokenizer 191 | logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...') 192 | config = AutoConfig.from_pretrained( 193 | args.model_checkpoint, 194 | cache_dir=args.cache_dir 195 | ) 196 | tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir) 197 | model = LongformerSoftmaxForTD.from_pretrained( 198 | args.model_checkpoint, 199 | config=config, 200 | cache_dir=args.cache_dir, 201 | args=args 202 | ).to(args.device) 203 | # Training 204 | if args.do_train: 205 | train_dataset = KBPTrigger(args.train_file) 206 | dev_dataset = KBPTrigger(args.dev_file) 207 | train(args, train_dataset, dev_dataset, model, tokenizer) 208 | # Testing 209 | save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')] 210 | test_dataset = KBPTrigger(args.test_file) # loaded unconditionally so --do_predict also works without --do_test 211 | if args.do_test: 212 | test(args, test_dataset, model, tokenizer, save_weights) 213 | # Predicting 214 | if args.do_predict: 215 | for save_weight in save_weights: 216 | logger.info(f'loading weights from {save_weight}...') 217 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight))) 218 | logger.info(f'predicting labels of {save_weight}...') 219 | 220 | results = [] 221 | model.eval() 222 | for sample in tqdm(test_dataset): 223 | pred_label = predict(args, sample['document'], model, tokenizer) 224 | results.append({ 225 | "doc_id": sample['id'], 226 | "document": sample['document'], 227 | "pred_label": pred_label, 228 | "true_label": sample['tags'] 229 | }) 230 | with open(os.path.join(args.output_dir, save_weight + '_test_pred_events.json'), 'wt', encoding='utf-8') as f: 231 | for example_result in results: 232 | f.write(json.dumps(example_result) + '\n') 233 | -------------------------------------------------------------------------------- /src/trigger_detection/run_td_softmax.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=./softmax_ce_results/ 2 | 3 | python3 run_td_softmax.py \ 4 | --output_dir=$OUTPUT_DIR \ 5 | --model_type=longformer \ 6 | --model_checkpoint=../../PT_MODELS/allenai/longformer-large-4096/ \ 7 | --train_file=../../data/train_filtered.json \ 8 | --dev_file=../../data/dev_filtered.json \ 9 | --test_file=../../data/test_filtered.json \ 10 | --max_seq_length=4096 \ 11 | --learning_rate=1e-5 \ 12 | --softmax_loss=ce \ 13 | --num_train_epochs=50 \ 14 | --batch_size=1 \ 15 | --do_train \ 16 | --warmup_proportion=0. \ 17 | --seed=42 --------------------------------------------------------------------------------