├── .gitignore
├── README.md
├── data
│   ├── SplitSentences
│   │   ├── .classpath
│   │   ├── .project
│   │   ├── .settings
│   │   │   ├── org.eclipse.core.resources.prefs
│   │   │   ├── org.eclipse.core.runtime.prefs
│   │   │   └── org.eclipse.jdt.core.prefs
│   │   ├── bin
│   │   │   ├── test
│   │   │   │   ├── Main.class
│   │   │   │   ├── Sentence.class
│   │   │   │   └── SourceReader.class
│   │   │   └── word
│   │   │       └── WordCount.class
│   │   ├── readme.md
│   │   └── src
│   │       ├── test
│   │       │   ├── Main.java
│   │       │   ├── Sentence.java
│   │       │   └── SourceReader.java
│   │       └── word
│   │           └── WordCount.java
│   ├── convert.py
│   ├── dev.json
│   ├── dev_filtered.json
│   ├── kbp_sent.txt
│   ├── kbp_vocab.txt
│   ├── kbp_word_count.txt
│   ├── run_info.log
│   ├── test.json
│   ├── test_filtered.json
│   ├── train.json
│   ├── train_filtered.json
│   └── utils.py
├── download_pt_models.sh
├── requirements.txt
└── src
    ├── ablation
    │   ├── arg.py
    │   ├── chunk_global_encoder.py
    │   ├── chunk_global_encoder.sh
    │   ├── data.py
    │   ├── modeling.py
    │   ├── without_global_encoder.py
    │   └── without_global_encoder.sh
    ├── analysis
    │   ├── run_analysis.py
    │   └── utils.py
    ├── clustering
    │   ├── arg.py
    │   ├── cluster.py
    │   ├── run_cluster.py
    │   ├── run_cluster.sh
    │   └── utils.py
    ├── global_event_coref
    │   ├── analysis.py
    │   ├── arg.py
    │   ├── data.py
    │   ├── modeling.py
    │   ├── run_global_base.py
    │   ├── run_global_base.sh
    │   ├── run_global_base_with_mask.py
    │   ├── run_global_base_with_mask.sh
    │   ├── run_global_base_with_mask_topic.py
    │   ├── run_global_base_with_mask_topic.sh
    │   ├── run_global_base_with_topic.py
    │   └── run_global_base_with_topic.sh
    ├── joint_model
    │   ├── arg.py
    │   ├── data.py
    │   ├── modeling.py
    │   ├── run_joint_base.py
    │   └── run_joint_base.sh
    ├── local_event_coref
    │   ├── arg.py
    │   ├── data.py
    │   ├── modeling.py
    │   ├── run_local_base.py
    │   ├── run_local_base.sh
    │   ├── run_local_base_with_mask.py
    │   ├── run_local_base_with_mask.sh
    │   ├── run_local_base_with_mask_topic.py
    │   ├── run_local_base_with_mask_topic.sh
    │   ├── run_local_base_with_topic.py
    │   └── run_local_base_with_topic.sh
    ├── tools.py
    └── trigger_detection
        ├── arg.py
        ├── data.py
        ├── modeling.py
        ├── run_td_crf.py
        ├── run_td_crf.sh
        ├── run_td_softmax.py
        └── run_td_softmax.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | cache/
7 | results/
8 | reference-coreference-scorers/
9 |
10 | data/LDC_TAC_KBP/
11 |
12 | # VS CODE
13 | .vscode/
14 |
15 | # MACOS
16 | .DS_Store
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | pip-wheel-metadata/
36 | share/python-wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .nox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | *.py,cover
63 | .hypothesis/
64 | .pytest_cache/
65 |
66 | # Translations
67 | *.mo
68 | *.pot
69 |
70 | # Django stuff:
71 | local_settings.py
72 | db.sqlite3
73 | db.sqlite3-journal
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106 | __pypackages__/
107 |
108 | # Celery stuff
109 | celerybeat-schedule
110 | celerybeat.pid
111 |
112 | # SageMath parsed files
113 | *.sage.py
114 |
115 | # Environments
116 | .env
117 | .venv
118 | env/
119 | venv/
120 | ENV/
121 | env.bak/
122 | venv.bak/
123 |
124 | # Spyder project settings
125 | .spyderproject
126 | .spyproject
127 |
128 | # Rope project settings
129 | .ropeproject
130 |
131 | # mkdocs documentation
132 | /site
133 |
134 | # mypy
135 | .mypy_cache/
136 | .dmypy.json
137 | dmypy.json
138 |
139 | # Pyre type checker
140 | .pyre/
141 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Improving Event Coreference Resolution Using Document-level and Topic-level Information
2 |
3 | This code was used in the paper:
4 |
5 | **"[Improving Event Coreference Resolution Using Document-level and Topic-level Information](https://aclanthology.org/2022.emnlp-main.454/)"**
6 | Sheng Xu, Peifeng Li and Qiaoming Zhu. EMNLP 2022.
7 |
8 | A simple pipeline model implemented in PyTorch for resolving within-document event coreference. The model was trained and evaluated on the KBP corpus.
9 |
10 | ## Set up
11 |
12 | #### Requirements
13 |
14 | Set up a Python virtual environment and run:
15 |
16 | ```bash
17 | python3 -m pip install -r requirements.txt
18 | ```
19 |
20 | #### Download the evaluation script
21 |
22 | Coreference results are obtained using the official [**Reference Coreference Scorer**](https://github.com/conll/reference-coreference-scorers). This scorer reports results in terms of AVG-F, the unweighted average of the F-scores of four commonly used coreference evaluation metrics: $\text{MUC}$ ([Vilain et al., 1995](https://www.aclweb.org/anthology/M95-1005/)), $B^3$ ([Bagga and Baldwin, 1998](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.47.5848&rep=rep1&type=pdf)), $\text{CEAF}_e$ ([Luo, 2005](https://www.aclweb.org/anthology/H05-1004/)) and $\text{BLANC}$ ([Recasens and Hovy, 2011](https://www.researchgate.net/profile/Eduard-Hovy/publication/231881781_BLANC_Implementing_the_Rand_index_for_coreference_evaluation/links/553122420cf2f2a588acdc95/BLANC-Implementing-the-Rand-index-for-coreference-evaluation.pdf)).
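Concretely, AVG-F is just the arithmetic mean of the four F-scores. A minimal sketch, using the `+ Local & Topic` row from the Results section below:

```python
# AVG-F is the unweighted mean of the four coreference F-scores.
# The numbers are the "+ Local & Topic" row from the Results section.
muc, b3, ceafe, blanc = 46.2, 57.4, 59.0, 42.0
avg_f = (muc + b3 + ceafe + blanc) / 4
print(round(avg_f, 2))  # 51.15, reported as 51.2 after rounding
```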
23 |
24 | Run (from inside the repo):
25 |
26 | ```bash
27 | cd ./
28 | git clone git@github.com:conll/reference-coreference-scorers.git
29 | ```
30 |
31 | #### Download pretrained models
32 |
33 | Download the pretrained model weights (e.g. `bert-base-cased`) from Huggingface [Model Hub](https://huggingface.co/models):
34 |
35 | ```bash
36 | bash download_pt_models.sh
37 | ```
38 |
39 | **Note:** this script will download all pretrained models used in our experiments into `../PT_MODELS/` (i.e., a sibling directory of this repo).
40 |
41 | #### Prepare the dataset
42 |
43 | This repo assumes access to the English corpora used in the TAC KBP Event Nugget Detection and Coreference tasks (i.e., [KBP 2015](http://cairo.lti.cs.cmu.edu/kbp/2015/event/), [KBP 2016](http://cairo.lti.cs.cmu.edu/kbp/2016/event/), and [KBP 2017](http://cairo.lti.cs.cmu.edu/kbp/2017/event/)). In total, they contain 648 + 169 + 167 = 984 documents, which are either newswire articles or discussion forum threads.
44 |
45 | ```
46 | '2015': [
47 | 'LDC_TAC_KBP/LDC2015E29/data/',
48 | 'LDC_TAC_KBP/LDC2015E68/data/',
49 | 'LDC_TAC_KBP/LDC2017E02/data/2015/training/',
50 | 'LDC_TAC_KBP/LDC2017E02/data/2015/eval/'
51 | ],
52 | '2016': [
53 | 'LDC_TAC_KBP/LDC2017E02/data/2016/eval/eng/nw/',
54 | 'LDC_TAC_KBP/LDC2017E02/data/2016/eval/eng/df/'
55 | ],
56 | '2017': [
57 | 'LDC_TAC_KBP/LDC2017E54/data/eng/nw/',
58 | 'LDC_TAC_KBP/LDC2017E54/data/eng/df/'
59 | ]
60 | ```
61 |
62 | | | KBP 2015 | KBP 2016 | KBP 2017 | All |
63 | | ---------------- | :------: | :-------: | :------: | :---: |
64 | | \#Documents | 648 | 169 | 167 | 984 |
65 | | \#Event mentions | 18739 | 4155 | 4375 | 27269 |
66 | | \#Event Clusters | 11603 | 3191 | 2963 | 17757 |
67 |
68 | Following [Lu & Ng (2021)](https://aclanthology.org/2021.emnlp-main.103/), we use LDC2015E29, E68, E73, E94 and LDC2016E64 as the training set (817 documents: 735 for training and the remaining 82 for parameter tuning), and report results on the KBP 2017 dataset.
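The dev documents are recovered deterministically by `split_dev` in `data/convert.py`, which searches for a combination of documents whose event and cluster counts match the target dev statistics. In outline:

```python
from itertools import combinations

import numpy as np

# Outline of split_dev in data/convert.py: return the indices of the first
# combination of `doc_num` documents whose total event and cluster counts
# match the published dev statistics (82 docs, 2382 events, 1502 clusters).
def find_dev_indices(event_nums, cluster_nums, doc_num, event_num, cluster_num):
    event_nums, cluster_nums = np.asarray(event_nums), np.asarray(cluster_nums)
    for idxs in combinations(range(len(event_nums)), doc_num):
        idxs = np.asarray(idxs)
        if event_nums[idxs].sum() == event_num and cluster_nums[idxs].sum() == cluster_num:
            return idxs
    return None
```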
69 |
70 | **Dataset Statistics:**
71 |
72 | | | Train | Dev | Test | All |
73 | | ---------------- | :---: | :--: | :--: | :---: |
74 | | \#Documents | 735 | 82 | 167 | 984 |
75 | | \#Event mentions | 20512 | 2382 | 4375 | 27269 |
76 | | \#Event Clusters | 13292 | 1502 | 2963 | 17757 |
77 |
78 | Then,
79 |
80 | 1. Split sentences and count verbs/entities in documents using Stanford CoreNLP (see [readme](data/SplitSentences/readme.md)), creating `kbp_sent.txt` and `kbp_word_count.txt` in the *data* folder.
81 |
82 | 2. Convert the original dataset into jsonlines format using:
83 |
84 | ```bash
85 | cd data/
86 |
87 | export DATA_DIR=
88 | python3 convert.py --kbp_data_dir $DATA_DIR
89 | ```
90 |
91 | **Note:** this script will create `train.json`, `dev.json`, and `test.json` in the *data* folder, as well as `train_filtered.json`, `dev_filtered.json`, and `test_filtered.json`, which filter out same-position and overlapping event mentions.
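Each output file is in jsonlines format, one document per line, with `doc_id`, `document`, `sentences`, `events`, and `clusters` fields (see `data/convert.py`). A quick way to inspect the result:

```python
import json

# Peek at the first converted document and its per-document statistics.
with open('train_filtered.json', 'rt', encoding='utf-8') as f:
    doc = json.loads(next(f))
print(doc['doc_id'], len(doc['events']), len(doc['clusters']))
```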
92 |
93 | ## Training
94 |
95 | #### Trigger Detection
96 |
97 | Train a sequence labeling model for Trigger Detection using the BIO tagging scheme (run with `--do_train`):
98 |
99 | ```bash
100 | cd src/trigger_detection/
101 |
102 | export OUTPUT_DIR=./softmax_ce_results/
103 |
104 | python3 run_td_softmax.py \
105 | --output_dir=$OUTPUT_DIR \
106 | --model_type=longformer \
107 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \
108 | --train_file=../../data/train_filtered.json \
109 | --dev_file=../../data/dev_filtered.json \
110 | --test_file=../../data/test_filtered.json \
111 | --max_seq_length=4096 \
112 | --learning_rate=1e-5 \
113 | --softmax_loss=ce \
114 | --num_train_epochs=50 \
115 | --batch_size=1 \
116 | --do_train \
117 | --warmup_proportion=0. \
118 | --seed=42
119 | ```
120 |
121 | After training, the model weights and the evaluation results on the **Dev** set will be saved in `$OUTPUT_DIR`.
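For reference, the BIO scheme tags the first token of a trigger as `B-<subtype>`, any following trigger tokens as `I-<subtype>`, and all other tokens as `O`. A hand-made illustration (not actual model output; `Conflict.Attack` is one of the KBP event subtypes):

```python
# Hand-made BIO illustration for a single-token trigger ("attacked").
tokens = ["Protesters", "attacked", "the", "police", "station", "."]
labels = ["O", "B-Conflict.Attack", "O", "O", "O", "O"]
```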
122 |
123 | #### Event Coreference
124 |
125 | Train the full version of our event coreference model (run with `--do_train`):
126 |
127 | ```bash
128 | cd src/global_event_coref/
129 |
130 | export OUTPUT_DIR=./MaskTopic_M-multi-cosine_results/
131 |
132 | python3 run_global_base_with_mask_topic.py \
133 | --output_dir=$OUTPUT_DIR \
134 | --model_type=longformer \
135 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \
136 | --mention_encoder_type=bert \
137 | --mention_encoder_checkpoint=../../../PT_MODELS/bert-base-cased/ \
138 | --topic_model=vmf \
139 | --topic_dim=32 \
140 | --topic_inter_map=64 \
141 | --train_file=../../data/train_filtered.json \
142 | --dev_file=../../data/dev_filtered.json \
143 | --test_file=../../data/test_filtered.json \
144 | --max_seq_length=4096 \
145 | --max_mention_length=256 \
146 | --learning_rate=1e-5 \
147 | --matching_style=multi_cosine \
148 | --softmax_loss=ce \
149 | --num_train_epochs=50 \
150 | --batch_size=1 \
151 | --do_train \
152 | --warmup_proportion=0. \
153 | --seed=42
154 | ```
155 |
156 | After training, the model weights and the evaluation results on the **Dev** set will be saved in `$OUTPUT_DIR`.
157 |
158 | ## Evaluation
159 |
160 | #### Trigger Detection
161 |
162 | Run *run_td_softmax.py* with `--do_test`:
163 |
164 | ```bash
165 | cd src/trigger_detection/
166 |
167 | export OUTPUT_DIR=./softmax_ce_results/
168 |
169 | python3 run_td_softmax.py \
170 | --output_dir=$OUTPUT_DIR \
171 | --model_type=longformer \
172 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \
173 | --train_file=../../data/train_filtered.json \
174 | --dev_file=../../data/dev_filtered.json \
175 | --test_file=../../data/test_filtered.json \
176 | --max_seq_length=4096 \
177 | --learning_rate=1e-5 \
178 | --softmax_loss=ce \
179 | --num_train_epochs=50 \
180 | --batch_size=1 \
181 | --do_test \
182 | --warmup_proportion=0. \
183 | --seed=42
184 | ```
185 |
186 | After evaluation, the results on the **Test** set will be saved in `$OUTPUT_DIR`. Use the `--do_predict` flag to predict subtype labels; the predictions, i.e., `XXX_test_pred_events.json`, will also be saved in `$OUTPUT_DIR`.
187 |
188 | #### Event Coreference
189 |
190 | Run *run_global_base_with_mask_topic.py* with `--do_test`:
191 |
192 | ```bash
193 | cd src/global_event_coref/
194 |
195 | export OUTPUT_DIR=./MaskTopic_M-multi-cosine_results/
196 |
197 | python3 run_global_base_with_mask_topic.py \
198 | --output_dir=$OUTPUT_DIR \
199 | --model_type=longformer \
200 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \
201 | --mention_encoder_type=bert \
202 | --mention_encoder_checkpoint=../../../PT_MODELS/bert-base-cased/ \
203 | --topic_model=vmf \
204 | --topic_dim=32 \
205 | --topic_inter_map=64 \
206 | --train_file=../../data/train_filtered.json \
207 | --dev_file=../../data/dev_filtered.json \
208 | --test_file=../../data/test_filtered.json \
209 | --max_seq_length=4096 \
210 | --max_mention_length=256 \
211 | --learning_rate=1e-5 \
212 | --matching_style=multi_cosine \
213 | --softmax_loss=ce \
214 | --num_train_epochs=50 \
215 | --batch_size=1 \
216 | --do_test \
217 | --warmup_proportion=0. \
218 | --seed=42
219 | ```
220 |
221 | After evaluation, the results on the **Test** set will be saved in `$OUTPUT_DIR`. Use the `--do_predict` flag to predict coreference for event mention pairs; the predictions, i.e., `XXX_test_pred_corefs.json`, will also be saved in `$OUTPUT_DIR`.
222 |
223 | #### Clustering
224 |
225 | Create the final event clusters from the predicted pairwise results:
226 |
227 | ```bash
228 | cd src/clustering
229 |
230 | export OUTPUT_DIR=./TEMP/
231 |
232 | python3 run_cluster.py \
233 | --output_dir=$OUTPUT_DIR \
234 | --test_golden_filepath=../../data/test.json \
235 | --test_pred_filepath=../../data/XXX_weights.bin_test_pred_corefs.json \
236 | --golden_conll_filename=gold_test.conll \
237 | --pred_conll_filename=pred_test.conll \
238 | --do_evaluate
239 | ```
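Conceptually, this step groups event mentions that are connected by predicted coreference links. A simplified union-find sketch of the idea (the repo's actual procedure lives in `src/clustering/cluster.py` and may differ):

```python
# Group mentions connected by predicted coreferent pairs (connected components).
def make_clusters(mention_ids, coref_pairs):
    parent = {m: m for m in mention_ids}

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for a, b in coref_pairs:
        parent[find(a)] = find(b)

    clusters = {}
    for m in mention_ids:
        clusters.setdefault(find(m), []).append(m)
    return list(clusters.values())

print(make_clusters(['e1', 'e2', 'e3'], [('e1', 'e3')]))  # [['e1', 'e3'], ['e2']]
```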
240 |
241 | ## Results
242 |
243 | #### Download Final Model
244 |
245 | You can download the final Trigger Detection & Event Coreference models at:
246 |
247 | [https://drive.google.com/drive/folders/182jll9UZ8yqQ93Dev92XDI0v2jhN7wcw?usp=sharing](https://drive.google.com/drive/folders/182jll9UZ8yqQ93Dev92XDI0v2jhN7wcw?usp=sharing)
248 |
249 | #### Trigger Detection
250 |
251 | | Model | Micro (P / R / F1) | Macro (P / R / F1) |
252 | | ------------------------------------------------------------ | :----------------: | :----------------: |
253 | | [(Lu & Ng, 2021)](https://aclanthology.org/2021.emnlp-main.103/) | 71.6 / 58.7 / 64.5 | - / - / - |
254 | | Longformer | 63.0 / 58.1 / 60.4 | 65.2 / 57.7 / 59.2 |
255 | | Longformer+CRF | 64.8 / 54.6 / 59.2 | 65.9 / 55.2 / 58.1 |
256 |
257 | #### Classical Pairwise Models
258 |
259 | | Model                       | Pairwise (P / R / F1) | MUC  | $B^3$ | $\text{CEAF}_e$ | BLANC | AVG-F |
260 | | --------------------------- | :-------------------: | :--: | :---: | :-------------: | :---: | :---: |
261 | | BERT-large[Prod] | 62.3 / 49.3 / 55.0 | 36.5 | 54.4 | 55.8 | 37.3 | 46.0 |
262 | | RoBERTa-large[Prod] | 64.6 / 44.0 / 52.4 | 36.0 | 54.8 | 55.6 | 37.3 | 45.9 |
263 | | BERT-large[Prod] + Local | 69.0 / 45.5 / 54.8 | 37.6 | 55.1 | 57.1 | 38.5 | 47.1 |
264 | | RoBERTa-large[Prod] + Local | 71.7 / 49.9 / 58.9 | 39.0 | 55.8 | 58.0 | 39.6 | 48.1 |
265 |
266 | #### Pairwise & Chunk Variants
267 |
268 | These variants replace the Global Mention Encoder in our model with a pairwise (sentence-level) encoder or a chunk (segment-level) encoder.
269 |
270 | | Model                  | Pairwise (P / R / F1) | MUC  | $B^3$ | $\text{CEAF}_e$ | BLANC | AVG-F |
271 | | ---------------------- | :-------------------: | :--: | :---: | :-------------: | :---: | :---: |
272 | | BERT-base[Pairwise] | 64.0 / 39.8 / 49.0 | 35.3 | 54.4 | 55.8 | 36.6 | 45.5 |
273 | | RoBERTa-base[Pairwise] | 59.9 / 55.6 / 57.7 | 39.0 | 54.3 | 56.4 | 38.6 | 47.1 |
274 | | BERT-base[Chunk] | 59.7 / 50.6 / 54.7 | 38.4 | 54.9 | 55.4 | 37.9 | 46.7 |
275 | | RoBERTa-base[Chunk] | 64.0 / 51.3 / 56.9 | 39.6 | 55.2 | 56.9 | 38.5 | 47.6 |
276 |
277 | #### Our Model
278 |
279 | | Model                                                            | Pairwise (P / R / F1) | MUC  | $B^3$ | $\text{CEAF}_e$ | BLANC | AVG-F |
280 | | ---------------------------------------------------------------- | :-------------------: | :--: | :---: | :-------------: | :---: | :---: |
281 | | [(Lu & Ng, 2021)](https://aclanthology.org/2021.emnlp-main.103/) | - | 45.2 | 54.7 | 53.8 | 38.2 | 48.0 |
282 | | Global | 74.7 / 63.2 / 68.4 | 45.4 | 57.3 | 58.7 | 42.2 | 50.9 |
283 | | + Local | 72.4 / 63.3 / 67.6 | 45.8 | 57.5 | 59.1 | 42.1 | 51.1 |
284 | | + Local & Topic | 72.0 / 64.4 / 68.0 | 46.2 | 57.4 | 59.0 | 42.0 | 51.2 |
285 |
286 | #### Variants using different tensor matching
287 |
288 | | Model              | Pairwise (P / R / F1) | MUC  | $B^3$ | $\text{CEAF}_e$ | BLANC | AVG-F |
289 | | ------------------ | :-------------------: | :--: | :---: | :-------------: | :---: | :---: |
290 | | Base | 37.5 / 48.0 / 42.1 | 36.7 | 54.9 | 55.3 | 34.7 | 45.4 |
291 | | Base+Prod | 71.2 / 64.0 / 67.4 | 45.4 | 57.0 | 58.6 | 41.2 | 50.5 |
292 | | Base+Prod+Cos | 72.0 / 64.4 / 68.0 | 46.2 | 57.4 | 59.0 | 42.0 | 51.2 |
293 | | Base+Prod+Diff | 70.3 / 67.1 / 68.7 | 45.0 | 56.7 | 58.9 | 41.4 | 50.5 |
294 | | Base+Prod+Diff+Cos | 69.5 / 65.9 / 67.6 | 44.4 | 56.5 | 58.6 | 41.2 | 50.2 |
295 |
296 | ## Contact info
297 |
298 | Contact [Sheng Xu](https://github.com/jsksxs360) at *[sxu@stu.suda.edu.cn](mailto:sxu@stu.suda.edu.cn)* for questions about this repository.
299 |
300 | ```
301 | @inproceedings{xu-etal-2022-improving,
302 | title = "Improving Event Coreference Resolution Using Document-level and Topic-level Information",
303 | author = "Xu, Sheng and
304 | Li, Peifeng and
305 | Zhu, Qiaoming",
306 | booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing",
307 | month = dec,
308 | year = "2022",
309 | address = "Abu Dhabi, United Arab Emirates",
310 | publisher = "Association for Computational Linguistics",
311 | url = "https://aclanthology.org/2022.emnlp-main.454",
312 | pages = "6765--6775"
313 | }
314 | ```
315 |
--------------------------------------------------------------------------------
/data/SplitSentences/.classpath:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/data/SplitSentences/.project:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <projectDescription>
3 | 	<name>SplitSentences</name>
4 | 	<comment></comment>
5 | 	<projects>
6 | 	</projects>
7 | 	<buildSpec>
8 | 		<buildCommand>
9 | 			<name>org.eclipse.jdt.core.javabuilder</name>
10 | 			<arguments>
11 | 			</arguments>
12 | 		</buildCommand>
13 | 	</buildSpec>
14 | 	<natures>
15 | 		<nature>org.eclipse.jdt.core.javanature</nature>
16 | 	</natures>
17 | </projectDescription>
18 |
--------------------------------------------------------------------------------
/data/SplitSentences/.settings/org.eclipse.core.resources.prefs:
--------------------------------------------------------------------------------
1 | eclipse.preferences.version=1
2 | encoding/=UTF-8
3 |
--------------------------------------------------------------------------------
/data/SplitSentences/.settings/org.eclipse.core.runtime.prefs:
--------------------------------------------------------------------------------
1 | eclipse.preferences.version=1
2 | line.separator=\n
3 |
--------------------------------------------------------------------------------
/data/SplitSentences/.settings/org.eclipse.jdt.core.prefs:
--------------------------------------------------------------------------------
1 | eclipse.preferences.version=1
2 | org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
3 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8
4 | org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
5 | org.eclipse.jdt.core.compiler.compliance=1.8
6 | org.eclipse.jdt.core.compiler.debug.lineNumber=generate
7 | org.eclipse.jdt.core.compiler.debug.localVariable=generate
8 | org.eclipse.jdt.core.compiler.debug.sourceFile=generate
9 | org.eclipse.jdt.core.compiler.problem.assertIdentifier=error
10 | org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
11 | org.eclipse.jdt.core.compiler.source=1.8
12 |
--------------------------------------------------------------------------------
/data/SplitSentences/bin/test/Main.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsksxs360/event-coref-emnlp2022/f6b82beee0558b187e63f50cf3064fc9ef9d39c8/data/SplitSentences/bin/test/Main.class
--------------------------------------------------------------------------------
/data/SplitSentences/bin/test/Sentence.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsksxs360/event-coref-emnlp2022/f6b82beee0558b187e63f50cf3064fc9ef9d39c8/data/SplitSentences/bin/test/Sentence.class
--------------------------------------------------------------------------------
/data/SplitSentences/bin/test/SourceReader.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsksxs360/event-coref-emnlp2022/f6b82beee0558b187e63f50cf3064fc9ef9d39c8/data/SplitSentences/bin/test/SourceReader.class
--------------------------------------------------------------------------------
/data/SplitSentences/bin/word/WordCount.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsksxs360/event-coref-emnlp2022/f6b82beee0558b187e63f50cf3064fc9ef9d39c8/data/SplitSentences/bin/word/WordCount.class
--------------------------------------------------------------------------------
/data/SplitSentences/readme.md:
--------------------------------------------------------------------------------
1 | ### Split sentences and count verbs/entities
2 |
3 | 1. Create the *SplitSentences/lib* folder if it does not exist.
4 | 2. Download **CoreNLP X.X.X** and the **English (KBP)** model jar from [CoreNLP](https://stanfordnlp.github.io/CoreNLP/index.html#quickstart).
5 | Unzip **CoreNLP X.X.X**, then move `slf4j-api.jar`, `slf4j-simple.jar`, `stanford-corenlp-x.x.x.jar`, `stanford-corenlp-x.x.x-models.jar`, and `stanford-corenlp-models-english-kbp.jar` to the *SplitSentences/lib* folder.
6 | 3. Download the [**Gson**](http://www.java2s.com/example/jar/g/gson-index.html) jar and move `gson-x.x.x.jar` to the *SplitSentences/lib* folder.
7 | 4. Run `Main.java` and `WordCount.java` to create `kbp_sent.txt` and `kbp_word_count.txt` in the *data* folder.
8 |
9 |
--------------------------------------------------------------------------------
/data/SplitSentences/src/test/Main.java:
--------------------------------------------------------------------------------
1 | package test;
2 |
3 | import java.io.BufferedWriter;
4 | import java.io.FileWriter;
5 | import java.io.IOException;
6 | import java.util.Arrays;
7 | import java.util.LinkedList;
8 | import java.util.List;
9 | import java.util.regex.Matcher;
10 | import java.util.regex.Pattern;
11 |
12 | public class Main {
13 |
14 | public static void main(String[] args) {
15 |
16 | String LDC2015E29 = "../LDC_TAC_KBP/LDC2015E29/data/source/mpdfxml/";
17 | String LDC2015E68 = "../LDC_TAC_KBP/LDC2015E68/data/source/";
18 |
19 | String KBP2015Train = "../LDC_TAC_KBP/LDC2017E02/data/2015/training/source/";
20 | String KBP2015Eval = "../LDC_TAC_KBP/LDC2017E02/data/2015/eval/source/";
21 |
22 | String KBP2016EvalNW = "../LDC_TAC_KBP/LDC2017E02/data/2016/eval/eng/nw/source/";
23 | String KBP2016EvalDF = "../LDC_TAC_KBP/LDC2017E02/data/2016/eval/eng/df/source/";
24 |
25 | String KBP2017EvalNW = "../LDC_TAC_KBP/LDC2017E54/data/eng/nw/source/";
26 | String KBP2017EvalDF = "../LDC_TAC_KBP/LDC2017E54/data/eng/df/source/";
27 |
28 | SourceReader reader = new SourceReader();
29 | List<Sentence> KBPSents = new LinkedList<>();
30 | try {
31 | // LDC2015E29
32 | List<Sentence> LDC2015E29Sents = reader.readSourceFolder(LDC2015E29);
33 | System.out.println("LDC2015E29: " + LDC2015E29Sents.size());
34 | KBPSents.addAll(LDC2015E29Sents);
35 | // LDC2015E68
36 | List<Sentence> LDC2015E68Sents = reader.readSourceFolder(LDC2015E68);
37 | System.out.println("LDC2015E68: " + LDC2015E68Sents.size());
38 | KBPSents.addAll(LDC2015E68Sents);
39 | // KBP2015
40 | List<Sentence> LDC2015TrainSents = reader.readSourceFolder(KBP2015Train);
41 | List<Sentence> LDC2015EvalSents = reader.readSourceFolder(KBP2015Eval);
42 | System.out.println("LDC2015: " + (LDC2015TrainSents.size() + LDC2015EvalSents.size()));
43 | KBPSents.addAll(LDC2015TrainSents);
44 | KBPSents.addAll(LDC2015EvalSents);
45 | // KBP 2016
46 | List<Sentence> KBP2016EvalNWSents = reader.readSourceFolder(KBP2016EvalNW, false);
47 | List<Sentence> KBP2016EvalDFSents = reader.readSourceFolder(KBP2016EvalDF, true);
48 | System.out.println("KBP2016: " + (KBP2016EvalNWSents.size() + KBP2016EvalDFSents.size()));
49 | KBPSents.addAll(KBP2016EvalNWSents);
50 | KBPSents.addAll(KBP2016EvalDFSents);
51 | // KBP 2017
52 | List<Sentence> KBP2017EvalNWSents = reader.readSourceFolder(KBP2017EvalNW, false);
53 | List<Sentence> KBP2017EvalDFSents = reader.readSourceFolder(KBP2017EvalDF, true);
54 | System.out.println("KBP2017: " + (KBP2017EvalNWSents.size() + KBP2017EvalDFSents.size()));
55 | KBPSents.addAll(KBP2017EvalNWSents);
56 | KBPSents.addAll(KBP2017EvalDFSents);
57 | } catch (Exception e) {
58 | e.printStackTrace();
59 | }
60 | try {
61 | saveFile("../kbp_sent.txt", KBPSents);
62 | } catch (IOException e) {
63 | e.printStackTrace();
64 | }
65 | }
66 |
67 | public static void saveFile(String filename, List<Sentence> sents) throws IOException {
68 | BufferedWriter writer = new BufferedWriter(new FileWriter(filename));
69 | for (Sentence sent : sents) {
70 | String text = sent.text.replace("\t", " ");
71 | if (isContainChinese(text) || text.startsWith("http") || text.startsWith("www.") || filter(text)) {
72 | continue;
73 | }
74 | writer.write(sent.filename + "\t" + sent.start + "\t" + text + "\n");
75 | }
76 | writer.close();
77 | }
78 |
79 | public static boolean filter(String str) {
80 | List<String> stopwords = Arrays.asList("P.S.", "PS", "snip",
81 | "&amp;", "&lt;", "&gt;", "&nbsp;", "&quot;",
82 | "#", "*", ".", "/", "year", "day", "month", "Â", "-", "[", "]",
83 | "!", "?", ",", ";", "(", ")", ":", "~", "_",
84 | "cof", "sigh", "shrug", "and", "or", "done", "URL");
85 | for (String w : stopwords) {
86 | str = str.replace(w, " ");
87 | }
88 | Pattern p = Pattern.compile("[0-9]");
89 | Matcher matcher = p.matcher(str);
90 | str = matcher.replaceAll(" ");
91 | if (str.trim().isEmpty() || str.trim().length() == 1) return true;
92 | return false;
93 | }
94 |
95 | public static boolean isContainChinese(String str) {
96 | Pattern p = Pattern.compile("[\u4E00-\u9FA5]");
97 | Matcher m = p.matcher(str);
98 | if (m.find()) {
99 | return true;
100 | }
101 | return false;
102 | }
103 |
104 | }
105 |
--------------------------------------------------------------------------------
/data/SplitSentences/src/test/Sentence.java:
--------------------------------------------------------------------------------
1 | package test;
2 |
3 | public class Sentence {
4 | public String filename;
5 | public String text;
6 | public int start;
7 |
8 | public Sentence(String filename, String text, int start) {
9 | this.filename = filename;
10 | this.text = text;
11 | this.start = start;
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/data/SplitSentences/src/test/SourceReader.java:
--------------------------------------------------------------------------------
1 | package test;
2 |
3 | import java.io.BufferedReader;
4 | import java.io.File;
5 | import java.io.FileReader;
6 | import java.util.Arrays;
7 | import java.util.LinkedList;
8 | import java.util.List;
9 | import java.util.Properties;
10 | import java.util.regex.Matcher;
11 | import java.util.regex.Pattern;
12 |
13 | import edu.stanford.nlp.pipeline.CoreDocument;
14 | import edu.stanford.nlp.pipeline.CoreSentence;
15 | import edu.stanford.nlp.pipeline.StanfordCoreNLP;
16 | import edu.stanford.nlp.util.StringUtils;
17 |
18 | public class SourceReader {
19 |
20 | private StanfordCoreNLP pipeline;
21 | private List<String> newsStart = Arrays.asList(new String[]{"AFP", "APW", "CNA", "NYT", "WPB", "XIN"});
22 |
23 | public SourceReader() {
24 | Properties props = new Properties();
25 | props.setProperty("annotators", "tokenize,ssplit");
26 | pipeline = new StanfordCoreNLP(props);
27 | }
28 |
29 | public List<Sentence> readSourceFolder(String folderPath, boolean isDF) throws Exception {
30 | File folder = new File(folderPath);
31 | List<Sentence> results = new LinkedList<>();
32 | for (String file : folder.list()) {
33 | results.addAll(readSourceFile(folderPath + file, isDF));
34 | }
35 | return results;
36 | }
37 |
38 | public List<Sentence> readSourceFolder(String folderPath) throws Exception {
39 | File folder = new File(folderPath);
40 | List<Sentence> results = new LinkedList<>();
41 | for (String file : folder.list()) {
42 | if (newsStart.contains(file.substring(0, 3))) { // News
43 | results.addAll(readSourceFile(folderPath + file, false));
44 | } else { // Forum
45 | results.addAll(readSourceFile(folderPath + file, true));
46 | }
47 | }
48 | return results;
49 | }
50 |
51 | public List<Sentence> readSourceFile(String filePath, boolean isDF) throws Exception {
52 | if (isDF) { // Forum
53 | return forumArticleReader(filePath, this.pipeline);
54 | } else { // News
55 | return newsArticleReader(filePath, this.pipeline);
56 | }
57 | }
58 |
59 | private static List<Sentence> forumArticleReader(String filePath, StanfordCoreNLP model) throws Exception {
60 | BufferedReader br = new BufferedReader(new FileReader(filePath));
61 | String filename = new File(filePath).getName();
62 | List<String> filters = Arrays.asList(".txt", ".xml", ".mpdf", ".cmp");
63 | for (String w : filters) {
64 | filename = filename.replace(w, "");
65 | }
66 | String line;
67 | int start = 0;
68 | List<Sentence> results = new LinkedList<>();
69 | while ((line = br.readLine()) != null) {
70 | int length = line.length() + 1;
71 | line = line.trim();
72 | if (line.startsWith("<")) { // skip markup lines (post/quote tags in the forum source)
73 | start += length;
74 | continue;
75 | }
76 | List sents = splitSentences(filename, line, start, model);
77 | results.addAll(sents);
78 | start += length;
79 | }
80 | br.close();
81 | return results;
82 | }
83 |
84 | private static List<Sentence> newsArticleReader(String filePath, StanfordCoreNLP model) throws Exception {
85 | BufferedReader br = new BufferedReader(new FileReader(filePath));
86 | String filename = new File(filePath).getName();
87 | List<String> filters = Arrays.asList(".txt", ".xml", ".mpdf", ".cmp");
88 | for (String w : filters) {
89 | filename = filename.replace(w, "");
90 | }
91 | String line;
92 | String Flag = "";
93 | String text = "";
94 | int start = 0;
95 | List<Sentence> results = new LinkedList<>();
96 | while ((line = br.readLine()) != null) {
97 | int length = line.length() + 1;
98 | if (line.trim().equals("<TEXT>")) { // start of body text
99 | Flag = "TEXT";
100 | start += length;
101 | continue;
102 | } else if (line.trim().equals("<P>") || line.trim().equals("<p>")) { // paragraph start
103 | Flag = "PARA";
104 | start += length;
105 | continue;
106 | } else if (line.trim().equals("</P>") || line.trim().equals("</p>")) { // paragraph end
107 | Flag = "";
108 | List sentences = splitSentences(filename, text, start, model);
109 | results.addAll(sentences);
110 | start += text.length() + length;
111 | text = "";
112 | continue;
113 | } else if (line.trim().equals("</TEXT>")) { // end of body text
114 | Flag = "";
115 | start += length;
116 | text = "";
117 | continue;
118 | }
119 | if (Flag.equals("PARA")) {
120 | text += line + " ";
121 | continue;
122 | } else if (Flag.equals("TEXT")) {
123 | List sentences = splitSentences(filename, line, start, model);
124 | results.addAll(sentences);
125 | start += length;
126 | continue;
127 | }
128 | start += length;
129 | }
130 | br.close();
131 | return results;
132 | }
133 |
134 | private static List<Sentence> splitSentences(String filename, String text, int start, StanfordCoreNLP pipeline) throws Exception {
135 | if (text.contains("<")) { // html file
136 | Pattern p_html = Pattern.compile("<[^>]+>", Pattern.CASE_INSENSITIVE);
137 | Matcher m_html = p_html.matcher(text);
138 | StringBuffer sb = new StringBuffer();
139 | while (m_html.find()) {
140 | m_html.appendReplacement(sb, StringUtils.repeat(" ", m_html.group().length()));
141 | }
142 | m_html.appendTail(sb);
143 | text = sb.toString();
144 | int count = 0;
145 | if (text.startsWith(" ")) {
146 | for (int i = 0; i < text.length(); i++) {
147 | if (text.charAt(i) != ' ') { break; }
148 | count += 1;
149 | }
150 | }
151 | text = text.trim();
152 | start += count;
153 | }
154 | // split sentence
155 | CoreDocument doc = new CoreDocument(text);
156 | pipeline.annotate(doc);
157 | List<Sentence> results = new LinkedList<>();
158 | for (CoreSentence sent : doc.sentences()) {
159 | Integer sentOffset = sent.charOffsets().first;
160 | String sentText = sent.text();
161 | if (sentText.isEmpty() || sentText.length() < 3) continue;
162 | results.add(new Sentence(filename, sentText, start + sentOffset));
163 | }
164 | return results;
165 | }
166 | }
167 |
--------------------------------------------------------------------------------
/data/SplitSentences/src/word/WordCount.java:
--------------------------------------------------------------------------------
1 | package word;
2 |
3 | import java.io.BufferedReader;
4 | import java.io.BufferedWriter;
5 | import java.io.FileNotFoundException;
6 | import java.io.FileReader;
7 | import java.io.FileWriter;
8 | import java.io.IOException;
9 | import java.util.Arrays;
10 | import java.util.HashMap;
11 | import java.util.List;
12 | import java.util.Map;
13 | import java.util.Properties;
14 |
15 | import com.google.gson.Gson;
16 |
17 | import edu.stanford.nlp.ling.CoreLabel;
18 | import edu.stanford.nlp.pipeline.CoreDocument;
19 | import edu.stanford.nlp.pipeline.CoreEntityMention;
20 | import edu.stanford.nlp.pipeline.StanfordCoreNLP;
21 |
22 | public class WordCount {
23 |
24 | private StanfordCoreNLP pipeline;
25 | List<String> entityType = Arrays.asList("PERSON", "LOCATION", "ORGANIZATION");
26 | List<String> stopwords = Arrays.asList(
27 | "a", "an", "and", "are", "as", "at", "be", "but", "by",
28 | "for", "if", "in", "into", "is", "it", "been",
29 | "no", "not", "of", "on", "or", "such",
30 | "that", "the", "their", "then", "there", "these",
31 | "they", "this", "to", "was", "will", "with",
32 | "he", "she", "his", "her", "were", "do"
33 | );
34 | public WordCount() {
35 | Properties props = new Properties();
36 | props.setProperty("annotators", "tokenize,ssplit,pos,lemma,ner");
37 | props.setProperty("ner.applyFineGrained", "false");
38 | props.setProperty("ner.applyNumericClassifiers", "false");
39 | pipeline = new StanfordCoreNLP(props);
40 | }
41 |
42 | public Map<String, Integer> getVerbEntity(String document) {
43 | CoreDocument doc = new CoreDocument(document);
44 | Map<String, Integer> wordStatistic = new HashMap<>();
45 | this.pipeline.annotate(doc);
46 | for (CoreLabel tok : doc.tokens()) {
47 | String word = tok.word().toLowerCase();
48 | if (this.stopwords.contains(word)) continue;
49 | if (tok.tag().startsWith("VB")) {
50 | wordStatistic.put(word, wordStatistic.getOrDefault(word, 0) + 1);
51 | }
52 | }
53 | for (CoreEntityMention em : doc.entityMentions()) {
54 | String entity = em.text().toLowerCase();
55 | if (this.stopwords.contains(entity) || !this.entityType.contains(em.entityType())) continue;
56 | wordStatistic.put(entity, wordStatistic.getOrDefault(entity, 0) + 1);
57 | }
58 | return wordStatistic;
59 | }
60 |
61 | public static void main(String[] args) throws IOException {
62 | String kbp_sent_filePath = "../kbp_sent.txt";
63 | BufferedReader br = new BufferedReader(new FileReader(kbp_sent_filePath));
64 | String line;
65 | Map<String, String> kbp_documents = new HashMap<>();
66 | while ((line = br.readLine()) != null) {
67 | String[] items = line.trim().split("\t");
68 | if (kbp_documents.containsKey(items[0])) {
69 | kbp_documents.replace(items[0], kbp_documents.get(items[0]) + " " + items[2]);
70 | } else {
71 | kbp_documents.put(items[0], items[2]);
72 | }
73 | }
74 | br.close();
75 | WordCount wc = new WordCount();
76 | Gson gson = new Gson();
77 | BufferedWriter bw = new BufferedWriter(new FileWriter("../kbp_word_count.txt"));
78 | for (Map.Entry<String, String> entry : kbp_documents.entrySet()) {
79 | System.out.println(entry.getKey());
80 | String countStr = gson.toJson(wc.getVerbEntity(entry.getValue()));
81 | bw.write(entry.getKey() + "\t" + countStr + "\n");
82 | }
83 | bw.close();
84 | }
85 |
86 | }
87 |
--------------------------------------------------------------------------------
/data/convert.py:
--------------------------------------------------------------------------------
1 | import collections
2 | from collections import namedtuple
3 | import xml.etree.ElementTree as ET
4 | import os
5 | import re
6 | from typing import Dict, List, Tuple
7 | import logging
8 | import json
9 | import numpy as np
10 | from itertools import combinations
11 | import argparse
12 | from utils import print_data_statistic, filter_events, check_event_conflict
13 |
14 | parser = argparse.ArgumentParser()
15 |
16 | parser.add_argument("--kbp_data_dir", default='LDC_TAC_KBP/', type=str)
17 | parser.add_argument("--sent_data_dir", default='./', type=str)
18 | args = parser.parse_args()
19 |
20 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
21 | datefmt='%Y/%m/%d %H:%M:%S',
22 | level=logging.INFO)
23 |
24 | logger = logging.getLogger("Convert")
25 |
26 | SENT_FILE = 'kbp_sent.txt'
27 | DATA_DIRS = {
28 | '2015': [
29 | 'LDC2015E29/data/ere/mpdfxml',
30 | 'LDC2015E68/data/ere',
31 | 'LDC2017E02/data/2015/training/event_hopper',
32 | 'LDC2017E02/data/2015/eval/hopper'
33 | ],
34 | '2016': [
35 | 'LDC2017E02/data/2016/eval/eng/nw/ere',
36 | 'LDC2017E02/data/2016/eval/eng/df/ere'
37 | ],
38 | '2017': [
39 | 'LDC2017E54/data/eng/nw/ere',
40 | 'LDC2017E54/data/eng/df/ere'
41 | ]
42 | }
43 |
44 | Sentence = namedtuple("Sentence", ["start", "text"])
45 | Filename = namedtuple("Filename", ["doc_id", "file_path"])
46 |
47 | def get_KBP_sents(sent_file_path:str) -> Dict[str, List[Sentence]]:
48 | '''get sentences in the KBP dataset
49 | # Returns:
50 | - sentence dictionary: {filename: [Sentence]}
51 | '''
52 | sent_dic = collections.defaultdict(list)
53 | with open(sent_file_path, 'rt', encoding='utf-8') as sents:
54 | for line in sents:
55 | doc_id, start, text = line.strip().split('\t')
56 | sent_dic[doc_id].append(Sentence(int(start), text))
57 | for sents in sent_dic.values():
58 | sents.sort(key=lambda x:x.start)
59 | return sent_dic
60 |
61 | def get_KBP_filenames(version:str) -> List[Filename]:
62 | '''get KBP filenames
63 | # Args:
64 | - version: 2015 / 2016 / 2017
65 | # Return:
66 | - filename list: [Filename]
67 | '''
68 | assert version in ['2015', '2016', '2017']
69 | filename_list = []
70 | for folder in DATA_DIRS[version]:
71 | filename_list += [
72 | Filename(
73 | re.sub(r'\.event_hoppers\.xml|\.rich_ere\.xml', '', filename),
74 | os.path.join(folder, filename)
75 | ) for filename in os.listdir(os.path.join(args.kbp_data_dir, folder))
76 | ]
77 | return filename_list
78 |
79 | def create_new_document(sent_list:List[Sentence]) -> str:
80 | '''create new source document
81 | '''
82 | document = ''
83 | end = 0
84 | for sent in sent_list:
85 | assert sent.start >= end
86 | document += ' ' * (sent.start - end)
87 | document += sent.text
88 | end = sent.start + len(sent.text)
89 | for sent in sent_list: # check
90 | assert document[sent.start:sent.start+len(sent.text)] == sent.text
91 | return document
92 |
93 | def find_event_sent(doc_id, event_start, trigger, sent_list) -> Tuple[int, int]:
94 | '''find out which sentence the event comes from
95 | '''
96 | for idx, sent in enumerate(sent_list):
97 | s_start, s_end = sent.start, sent.start + len(sent.text) - 1
98 | if s_start <= event_start <= s_end:
99 | e_s_start = event_start - s_start
100 | assert sent.text[e_s_start:e_s_start+len(trigger)] == trigger
101 | return idx, event_start - s_start
102 | print(doc_id)
103 | print(event_start, trigger, '\n')
104 | for sent in sent_list:
105 | print(sent.start, sent.start + len(sent.text) - 1)
106 | return None
107 |
108 | def update_trigger(text, trigger, offset):  # extend a trigger annotated mid-token to the full token
109 | punc_set = set('#$%&+=@.,;!?*\\~\'\n\r\t()[]|/’-:{<>}、"。,?“”')
110 | new_trigger = trigger
111 | if offset + len(trigger) < len(text) and text[offset + len(trigger)] != ' ' and text[offset + len(trigger)] not in punc_set:
112 | for c in text[offset + len(trigger):]:
113 | if c == ' ' or c in punc_set:
114 | break
115 | new_trigger += c
116 | new_trigger = new_trigger.strip('\n\r\t')
117 | new_trigger = new_trigger.strip(u'\x94')
118 | if new_trigger != trigger:
119 | logger.warning(f'update: [{trigger}]({len(trigger)}) - [{new_trigger}]({len(new_trigger)})')
120 | return new_trigger
121 |
122 | def xml_parser(file_path:str, sent_list:List[Sentence]) -> Dict:
123 | '''KBP datafile XML parser
124 | # Args:
125 | - file_path: xml file path
126 | - sent_list: Sentences of file
127 | '''
128 | tree = ET.ElementTree(file=file_path)
129 | doc_id = re.sub(r'\.event_hoppers\.xml|\.rich_ere\.xml', '', os.path.split(file_path)[1])
130 | document = create_new_document(sent_list)
131 | sentence_list = [{'start': sent.start, 'text': sent.text} for sent in sent_list]
132 | event_list = []
133 | cluster_list = []
134 | for hopper in tree.iter(tag='hopper'):
135 | h_id = hopper.attrib['id'] # hopper id
136 | h_events = []
137 | for event in hopper.iter(tag='event_mention'):
138 | att = event.attrib
139 | e_id = att['id']
140 | e_type, e_subtype, e_realis = att['type'], att['subtype'], att['realis']
141 | e_trigger = event.find('trigger').text.strip()
142 | e_start = int(event.find('trigger').attrib['offset'])
143 | e_s_index, e_s_start = find_event_sent(doc_id, e_start, e_trigger, sent_list)
144 | e_trigger = update_trigger(sent_list[e_s_index].text, e_trigger, e_s_start)
145 | event_list.append({
146 | 'event_id': e_id,
147 | 'start': e_start,
148 | 'trigger': e_trigger,
149 | 'type': e_type,
150 | 'subtype': e_subtype,
151 | 'realis': e_realis,
152 | 'sent_idx': e_s_index,
153 | 'sent_start': e_s_start
154 | })
155 | h_events.append(e_id)
156 | cluster_list.append({
157 | 'hopper_id': h_id,
158 | 'events': h_events
159 | })
160 | return {
161 | 'doc_id': doc_id,
162 | 'document': document,
163 | 'sentences': sentence_list,
164 | 'events': event_list,
165 | 'clusters': cluster_list
166 | }
167 |
168 | def split_dev(doc_list:list, valid_doc_num:int, valid_event_num:int, valid_chain_num:int):
169 | '''split dev set from full train set
170 | '''
171 | docs_id = [doc['doc_id'] for doc in doc_list]
172 | docs_event_num = np.asarray([len(doc['events']) for doc in doc_list])
173 | docs_event_num[docs_id.index('bolt-eng-DF-170-181109-47916')] += 2  # adjust counts so the search matches
174 | docs_event_num[docs_id.index('bolt-eng-DF-170-181109-48534')] += 1  # the official statistics (cf. run_info.log)
175 | docs_cluster_num = np.asarray([len(doc['clusters']) for doc in doc_list])
176 | logger.info(f'Train & Dev set: Doc: {len(docs_id)} | Event: {docs_event_num.sum()} | Cluster: {docs_cluster_num.sum()}')
177 | train_docs, dev_docs = [], []
178 | logger.info(f'finding the correct split...')
179 | for indexs in combinations(range(len(docs_id)), valid_doc_num):
180 | indexs = np.asarray(indexs)
181 | if (
182 | docs_event_num[indexs].sum() == valid_event_num and
183 | docs_cluster_num[indexs].sum() == valid_chain_num
184 | ):
185 | logger.info(f'Done!')
186 | for idx, doc in enumerate(doc_list):
187 | if idx in indexs:
188 | dev_docs.append(doc)
189 | else:
190 | train_docs.append(doc)
191 | break
192 | return train_docs, dev_docs
193 |
194 | if __name__ == "__main__":
195 | docs = collections.defaultdict(list)
196 | kbp_sent_list = get_KBP_sents(os.path.join(args.sent_data_dir, SENT_FILE))
197 | for dataset in ['2015', '2016', '2017']:
198 | logger.info(f"parsing xml files in KBP {dataset} ...")
199 | for filename in get_KBP_filenames(dataset):
200 | doc_results = xml_parser(os.path.join(args.kbp_data_dir, filename.file_path), kbp_sent_list[filename.doc_id])
201 | docs[f'kbp_{dataset}'].append(doc_results)
202 | logger.info(f"Finished!")
203 | print_data_statistic(docs[f'kbp_{dataset}'], dataset)
204 | # split Dev set
205 | train_docs, dev_docs = split_dev(docs['kbp_2015'] + docs['kbp_2016'], 82, 2382, 1502)
206 | kbp_dataset = {
207 | 'train': train_docs,
208 | 'dev': dev_docs,
209 | 'test': docs['kbp_2017']
210 | }
211 | for doc_list in kbp_dataset.values():
212 | doc_list.sort(key=lambda x:x['doc_id'])
213 | for dataset in ['train', 'dev', 'test']:
214 | logger.info(f"saving {dataset} set ...")
215 | dataset_doc_list = kbp_dataset[dataset]
216 | print_data_statistic(dataset_doc_list, dataset)
217 | with open(f'{dataset}.json', 'wt', encoding='utf-8') as f:
218 | for doc in dataset_doc_list:
219 | f.write(json.dumps(doc) + '\n')
220 | logger.info(f"Finished!")
221 | # filter events & clusters
222 | for dataset in ['train', 'dev', 'test']:
223 | dataset_doc_list = filter_events(kbp_dataset[dataset], dataset)
224 | check_event_conflict(dataset_doc_list)
225 | print_data_statistic(dataset_doc_list, dataset)
226 | logger.info(f"saving filtered {dataset} set ...")
227 | with open(f'{dataset}_filtered.json', 'wt', encoding='utf-8') as f:
228 | for doc in dataset_doc_list:
229 | f.write(json.dumps(doc) + '\n')
230 | logger.info(f"Finished!")
231 |
--------------------------------------------------------------------------------
/data/kbp_vocab.txt:
--------------------------------------------------------------------------------
1 | ["have", "said", "has", "'s", "had", "did", "him", "think", "get", "know", "does", "'m", "see", "being", "going", "say", "go", "make", "made", "got", "want", "am", "take", "'re", "us", "apple", "told", "china", "obama", "u.s.", "'ve", "need", "according", "including", "let", "called", "used", "give", "pay", "saying", "went", "keep", "come", "believe", "killed", "convicted", "bush", "found", "left", "read", "find", "done", "came", "seems", "put", "having", "took", "doing", "work", "says", "use", "given", "united states", "feel", "thought", "arrested", "trying", "syria", "getting", "uk", "look", "making", "happened", "buy", "help", "elected", "israel", "posted", "tell", "hope", "agree", "wanted", "like", "become", "makes", "stop", "charged", "seen", "start", "known", "sent", "asked", "reported", "using", "russia", "allowed", "set", "understand", "involved", "mean", "based", "started", "taken", "died", "try", "run", "paid", "mandela", "live", "iraq", "held", "iran", "america", "heard", "saw", "looking", "happen", "working", "support", "taking", "leave", "expected", "released", "seem", "tried", "xinhua", "wants", "washington", "sentenced", "appear", "nokia", "coming", "remember", "announced", "guess", "eu", "love", "worked", "gave", "accused", "lost", "pakistan", "gets", "goes", "continue", "microsoft", "stay", "began", "served", "fired", "met", "became", "ask", "pardoned", "talking", "living", "gone", "egypt", "india", "hit", "move", "brought", "filed", "care", "following", "kill", "hear", "turned", "knew", "decided", "agreed", "needs", "vote", "snowden", "comes", "call", "allow", "led", "new york", "new york times", "sounds", "received", "knows", "leaving", "bring", "running", "killing", "wait", "north korea", "send", "talk", "die", "meet", "needed", "thinking", "cyprus", "means", "win", "congress", "show", "caused", "born", "explain", "clinton", "calling", "added", "wish", "giving", "florida", "london", "bought", "ordered", "europe", "issued", "stand", "spend", "thank", "shot", "google", "hold", "senate", "forced", "provide", "white house", "seeing", "wonder", "change", "related", "married", "claimed", "end", "texas", "takes", "philippines", "spent", "deal", "moved", "wrote", "foxconn", "france", "looks", "cut", "telling", "un", "follow", "morsi", "sell", "turn", "speak", "paying", "passed", "condensed", "won", "scotland", "usa", "nelson mandela", "showed", "committed", "face", "appointed", "injured", "ruled", "barack obama", "britain", "ukraine", "lose", "buying", "hate", "fighting", "granted", "denied", "confirmed", "germany", "happens", "kept", "considered", "sandusky", "chun", "failed", "remain", "helped", "works", "serve", "gives", "'d", "affected", "caught", "afghanistan", "changed", "leading", "receive", "owned", "cause", "created", "supreme court", "starting", "lead", "believed", "thinks", "sold", "japan", "trump", "followed", "mention", "supposed", "expect", "serving", "realize", "planned", "detained", "refused", "return", "remains", "felt", "extradited", "protect", "doubt", "blame", "moving", "built", "appeared", "offered", "reached", "include", "speaking", "waiting", "carry", "declined", "included", "bangladesh", "istanbul", "create", "spain", "california", "beijing", "decide", "supporting", "helping", "considering", "provided", "lived", "avoid", "rejected", "attacked", "arrived", "executed", "described", "named", "voted", "stopped", "claiming", "vietnam", "shows", "appears", "ended", "check", "returned", "discuss", "asking", "australia", "learn", "bet", "knowing", "imagine", 
"adding", "prove", "deserve", "remained", "consider", "putting", "save", "watch", "continued", "fight", "broke", "accept", "expressed", "ran", "join", "selling", "supported", "wanting", "written", "watching", "suggest", "pardon", "holding", "address", "cairo", "claim", "learned", "turkey", "pass", "played", "united nations", "carried", "signed", "seeking", "building", "grow", "libya", "south africa", "prevent", "losing", "assad", "treated", "meant", "scheduled", "driving", "happening", "add", "lying", "sound", "disagree", "raped", "travel", "mexico", "joined", "becoming", "stated", "required", "defend", "canada", "seemed", "haiyan", "afford", "attempted", "samsung", "gotten", "hired", "steve", "italy", "nominated", "register", "cost", "planning", "shown", "chang", "sought", "putin", "missing", "represent", "cover", "army", "published", "declared", "intended", "worry", "admitted", "looked", "growing", "sending", "urged", "ignore", "steve jobs", "moscow", "quoted", "raised", "post", "build", "zimmerman", "muslim brotherhood", "brazil", "reporting", "fined", "retired", "demanding", "gm", "covered", "clicking", "reading", "begin", "forget", "acting", "hoping", "resigned", "sitting", "argue", "compared", "sued", "fuck", "breaking", "drive", "cia", "dropped", "walk", "admit", "middle east", "ensure", "walked", "england", "watched", "beat", "force", "opposed", "south korea", "lives", "pick", "regarding", "justice department"]
--------------------------------------------------------------------------------
/data/run_info.log:
--------------------------------------------------------------------------------
1 | 2022/04/09 04:50:42 - INFO - Convert - parsing xml files in KBP 2015 ...
2 | 2022/04/09 04:50:42 - WARNING - Convert - update: [EX-](3) - [EX-SOUTH](8)
3 | 2022/04/09 04:50:42 - WARNING - Convert - update: [manufacture](11) - [manufacturer](12)
4 | 2022/04/09 04:50:43 - WARNING - Convert - update: [explanation](11) - [explanations](12)
5 | 2022/04/09 04:50:43 - WARNING - Convert - update: [explanation](11) - [explanations](12)
6 | 2022/04/09 04:50:43 - WARNING - Convert - update: [explanation](11) - [explanations](12)
7 | 2022/04/09 04:50:43 - INFO - Convert - Finished!
8 | 2022/04/09 04:50:43 - INFO - Utils - KBP 2015 - Doc: 648 | Event: 18736 | Cluster: 11603 | Singleton: 8484
9 | 2022/04/09 04:50:43 - INFO - Convert - parsing xml files in KBP 2016 ...
10 | 2022/04/09 04:50:44 - INFO - Convert - Finished!
11 | 2022/04/09 04:50:44 - INFO - Utils - KBP 2016 - Doc: 169 | Event: 4155 | Cluster: 3191 | Singleton: 2709
12 | 2022/04/09 04:50:44 - INFO - Convert - parsing xml files in KBP 2017 ...
13 | 2022/04/09 04:50:44 - WARNING - Convert - update: [-](1) - [-1996](5)
14 | 2022/04/09 04:50:45 - INFO - Convert - Finished!
15 | 2022/04/09 04:50:45 - INFO - Utils - KBP 2017 - Doc: 167 | Event: 4375 | Cluster: 2963 | Singleton: 2358
16 |
17 | 2022/04/09 04:50:45 - INFO - Convert - Train & Dev set: Doc: 817 | Event: 22894 | Cluster: 14794
18 | 2022/04/09 04:50:45 - INFO - Convert - finding the correct split...
19 | 2022/04/09 05:09:33 - INFO - Convert - Done!
20 | 2022/04/09 05:09:33 - INFO - Convert - saving train set ...
21 | 2022/04/09 05:09:33 - INFO - Utils - KBP train - Doc: 735 | Event: 20509 | Cluster: 13292 | Singleton: 10067
22 | 2022/04/09 05:09:33 - INFO - Convert - Finished!
23 | 2022/04/09 05:09:33 - INFO - Convert - saving dev set ...
24 | 2022/04/09 05:09:33 - INFO - Utils - KBP dev - Doc: 82 | Event: 2382 | Cluster: 1502 | Singleton: 1126
25 | 2022/04/09 05:09:33 - INFO - Convert - Finished!
26 | 2022/04/09 05:09:33 - INFO - Convert - saving test set ...
27 | 2022/04/09 05:09:33 - INFO - Utils - KBP test - Doc: 167 | Event: 4375 | Cluster: 2963 | Singleton: 2358
28 | 2022/04/09 05:09:33 - INFO - Convert - Finished!
29 |
30 | 2022/04/09 05:09:33 - INFO - Filter - KBP train event filtered: 1629 (same 1621 / overlapping 8)
31 | 2022/04/09 05:09:33 - INFO - Filter - KBP train cluster filtered: 951
32 | 2022/04/09 05:09:33 - INFO - Utils - KBP train - Doc: 735 | Event: 18880 | Cluster: 12341 | Singleton: 9369
33 | 2022/04/09 05:09:33 - INFO - Convert - saving filtered train set ...
34 | 2022/04/09 05:09:33 - INFO - Convert - Finished!
35 | 2022/04/09 05:09:33 - INFO - Filter - KBP dev event filtered: 200 (same 198 / overlapping 2)
36 | 2022/04/09 05:09:33 - INFO - Filter - KBP dev cluster filtered: 100
37 | 2022/04/09 05:09:33 - INFO - Utils - KBP dev - Doc: 82 | Event: 2182 | Cluster: 1402 | Singleton: 1051
38 | 2022/04/09 05:09:33 - INFO - Convert - saving filtered dev set ...
39 | 2022/04/09 05:09:33 - INFO - Convert - Finished!
40 | 2022/04/09 05:09:33 - INFO - Filter - KBP test event filtered: 379 (same 378 / overlapping 1)
41 | 2022/04/09 05:09:33 - INFO - Filter - KBP test cluster filtered: 256
42 | 2022/04/09 05:09:33 - INFO - Utils - KBP test - Doc: 167 | Event: 3996 | Cluster: 2707 | Singleton: 2161
43 | 2022/04/09 05:09:33 - INFO - Convert - saving filtered test set ...
44 | 2022/04/09 05:09:33 - INFO - Convert - Finished!
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
4 | datefmt='%Y/%m/%d %H:%M:%S',
5 | level=logging.INFO)
6 |
7 | logger = logging.getLogger("Utils")
8 | filter_logger = logging.getLogger("Filter")
9 |
10 | def print_data_statistic(doc_list, dataset=''):
11 | doc_num = len(doc_list)
12 | event_num = sum([len(doc['events']) for doc in doc_list])
13 | cluster_num = sum([len(doc['clusters']) for doc in doc_list])
14 | singleton_num = sum([1 if len(cluster['events']) == 1 else 0
15 | for doc in doc_list for cluster in doc['clusters']])
16 | logger.info(f"KBP {dataset} - Doc: {doc_num} | Event: {event_num} | Cluster: {cluster_num} | Singleton: {singleton_num}")
17 |
18 | def check_event_conflict(doc_list):
19 | for doc in doc_list:
20 | event_list = doc['events']
21 | event_list.sort(key=lambda x:x['start'])
22 | if len(event_list) < 2:
23 | continue
24 | for idx in range(len(event_list)-1):
25 | if (
26 | (
27 | event_list[idx]['start'] == event_list[idx+1]['start'] and
28 | event_list[idx]['trigger'] == event_list[idx+1]['trigger']
29 | ) or
30 | (
31 | event_list[idx]['start'] + len(event_list[idx]['trigger']) >
32 | event_list[idx+1]['start']
33 | )
34 | ):
35 | logger.error('{}: ({})[{}] VS ({})[{}]'.format(doc['doc_id'],
36 | event_list[idx]['start'], event_list[idx]['trigger'],
37 | event_list[idx+1]['start'], event_list[idx+1]['trigger']))
38 |
39 | def filter_events(doc_list, dataset=''):  # drop same-position duplicates and overlapping mentions
40 | same = 0
41 | overlapping = 0
42 | cluster_num_filtered = 0
43 | for doc in doc_list:
44 | event_list = doc['events']
45 | event_list.sort(key=lambda x:x['start'])
46 | event_filtered = []
47 | if len(event_list) < 2:
48 | continue
49 | new_event_list, should_add = [], True
50 | for idx in range(len(event_list)-1):
51 | if (event_list[idx]['start'] == event_list[idx+1]['start'] and
52 | event_list[idx]['trigger'] == event_list[idx+1]['trigger']
53 | ):
54 | event_filtered.append(event_list[idx]['event_id'])
55 | same += 1
56 | continue
57 | if (event_list[idx]['start'] + len(event_list[idx]['trigger']) >
58 | event_list[idx+1]['start']
59 | ):
60 | overlapping += 1
61 | if len(event_list[idx]['trigger']) < len(event_list[idx+1]['trigger']):
62 | new_event_list.append(event_list[idx])
63 | should_add = False
64 | else:
65 | event_filtered.append(event_list[idx]['event_id'])
66 | continue
67 | if should_add:
68 | new_event_list.append(event_list[idx])
69 | else:
70 | event_filtered.append(event_list[idx]['event_id'])
71 | should_add = True
72 | if should_add:
73 | new_event_list.append(event_list[-1])
74 | doc['events'] = new_event_list
75 | new_clusters = []
76 | for cluster in doc['clusters']:
77 | new_events = [event_id for event_id in cluster['events'] if event_id not in event_filtered]
78 | if len(new_events) == 0:
79 | cluster_num_filtered += 1
80 | continue
81 | new_clusters.append({
82 | 'hopper_id': cluster['hopper_id'],
83 | 'events': new_events
84 | })
85 | doc['clusters'] = new_clusters
86 | filter_logger.info(f'KBP {dataset} event filtered: {same + overlapping} (same {same} / overlapping {overlapping})')
87 | filter_logger.info(f'KBP {dataset} cluster filtered: {cluster_num_filtered}')
88 | return doc_list
--------------------------------------------------------------------------------
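A minimal illustration (hypothetical document and IDs) of how filter_events above drops a duplicated trigger and prunes emptied clusters; run it from inside data/ so utils.py is importable:

    from utils import filter_events

    docs = [{
        'doc_id': 'doc_0',
        'events': [
            {'event_id': 'e1', 'start': 10, 'trigger': 'attack'},
            {'event_id': 'e2', 'start': 10, 'trigger': 'attack'},   # same span and trigger: filtered as "same"
            {'event_id': 'e3', 'start': 40, 'trigger': 'meeting'},
        ],
        'clusters': [
            {'hopper_id': 'h1', 'events': ['e1', 'e2']},  # shrinks to ['e2']
            {'hopper_id': 'h2', 'events': ['e3']},
        ],
    }]
    docs = filter_events(docs, dataset='demo')
    # logs: KBP demo event filtered: 1 (same 1 / overlapping 0)
    #       KBP demo cluster filtered: 0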
/download_pt_models.sh:
--------------------------------------------------------------------------------
1 | mkdir -p ../PT_MODELS/bert-base-cased/
2 | wget -P ../PT_MODELS/bert-base-cased/ https://huggingface.co/bert-base-cased/resolve/main/pytorch_model.bin
3 | wget -P ../PT_MODELS/bert-base-cased/ https://huggingface.co/bert-base-cased/resolve/main/README.md
4 | wget -P ../PT_MODELS/bert-base-cased/ https://huggingface.co/bert-base-cased/resolve/main/config.json
5 | wget -P ../PT_MODELS/bert-base-cased/ https://huggingface.co/bert-base-cased/resolve/main/tokenizer.json
6 | wget -P ../PT_MODELS/bert-base-cased/ https://huggingface.co/bert-base-cased/resolve/main/tokenizer_config.json
7 | wget -P ../PT_MODELS/bert-base-cased/ https://huggingface.co/bert-base-cased/resolve/main/vocab.txt
8 | mkdir -p ../PT_MODELS/bert-large-cased/
9 | wget -P ../PT_MODELS/bert-large-cased/ https://huggingface.co/bert-large-cased/resolve/main/pytorch_model.bin
10 | wget -P ../PT_MODELS/bert-large-cased/ https://huggingface.co/bert-large-cased/resolve/main/README.md
11 | wget -P ../PT_MODELS/bert-large-cased/ https://huggingface.co/bert-large-cased/resolve/main/config.json
12 | wget -P ../PT_MODELS/bert-large-cased/ https://huggingface.co/bert-large-cased/resolve/main/tokenizer.json
13 | wget -P ../PT_MODELS/bert-large-cased/ https://huggingface.co/bert-large-cased/resolve/main/tokenizer_config.json
14 | wget -P ../PT_MODELS/bert-large-cased/ https://huggingface.co/bert-large-cased/resolve/main/vocab.txt
15 | mkdir -p ../PT_MODELS/roberta-base/
16 | wget -P ../PT_MODELS/roberta-base/ https://huggingface.co/roberta-base/resolve/main/pytorch_model.bin
17 | wget -P ../PT_MODELS/roberta-base/ https://huggingface.co/roberta-base/resolve/main/README.md
18 | wget -P ../PT_MODELS/roberta-base/ https://huggingface.co/roberta-base/resolve/main/config.json
19 | wget -P ../PT_MODELS/roberta-base/ https://huggingface.co/roberta-base/resolve/main/dict.txt
20 | wget -P ../PT_MODELS/roberta-base/ https://huggingface.co/roberta-base/resolve/main/merges.txt
21 | wget -P ../PT_MODELS/roberta-base/ https://huggingface.co/roberta-base/resolve/main/tokenizer.json
22 | wget -P ../PT_MODELS/roberta-base/ https://huggingface.co/roberta-base/resolve/main/vocab.json
23 | mkdir -p ../PT_MODELS/roberta-large/
24 | wget -P ../PT_MODELS/roberta-large/ https://huggingface.co/roberta-large/resolve/main/pytorch_model.bin
25 | wget -P ../PT_MODELS/roberta-large/ https://huggingface.co/roberta-large/resolve/main/README.md
26 | wget -P ../PT_MODELS/roberta-large/ https://huggingface.co/roberta-large/resolve/main/config.json
27 | wget -P ../PT_MODELS/roberta-large/ https://huggingface.co/roberta-large/resolve/main/merges.txt
28 | wget -P ../PT_MODELS/roberta-large/ https://huggingface.co/roberta-large/resolve/main/tokenizer.json
29 | wget -P ../PT_MODELS/roberta-large/ https://huggingface.co/roberta-large/resolve/main/vocab.json
30 | mkdir -p ../PT_MODELS/SpanBERT/spanbert-base-cased
31 | wget -P ../PT_MODELS/SpanBERT/spanbert-base-cased https://huggingface.co/SpanBERT/spanbert-base-cased/resolve/main/pytorch_model.bin
32 | wget -P ../PT_MODELS/SpanBERT/spanbert-base-cased https://huggingface.co/SpanBERT/spanbert-base-cased/resolve/main/config.json
33 | wget -P ../PT_MODELS/SpanBERT/spanbert-base-cased https://huggingface.co/SpanBERT/spanbert-base-cased/resolve/main/vocab.txt
34 | mkdir -p ../PT_MODELS/SpanBERT/spanbert-large-cased
35 | wget -P ../PT_MODELS/SpanBERT/spanbert-large-cased https://huggingface.co/SpanBERT/spanbert-large-cased/resolve/main/pytorch_model.bin
36 | wget -P ../PT_MODELS/SpanBERT/spanbert-large-cased https://huggingface.co/SpanBERT/spanbert-large-cased/resolve/main/config.json
37 | wget -P ../PT_MODELS/SpanBERT/spanbert-large-cased https://huggingface.co/SpanBERT/spanbert-large-cased/resolve/main/vocab.txt
38 | mkdir -p ../PT_MODELS/allenai/longformer-large-4096/
39 | wget -P ../PT_MODELS/allenai/longformer-large-4096/ https://huggingface.co/allenai/longformer-large-4096/resolve/main/pytorch_model.bin
40 | wget -P ../PT_MODELS/allenai/longformer-large-4096/ https://huggingface.co/allenai/longformer-large-4096/resolve/main/config.json
41 | wget -P ../PT_MODELS/allenai/longformer-large-4096/ https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt
42 | wget -P ../PT_MODELS/allenai/longformer-large-4096/ https://huggingface.co/allenai/longformer-large-4096/resolve/main/tokenizer.json
43 | wget -P ../PT_MODELS/allenai/longformer-large-4096/ https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json
44 |
--------------------------------------------------------------------------------
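Once downloaded, the checkpoints load fully offline by pointing transformers at the local directories; a minimal sketch (the relative path assumes the layout created by the script above):

    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("../PT_MODELS/bert-base-cased/")
    model = AutoModel.from_pretrained("../PT_MODELS/bert-base-cased/")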
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.22.3
2 | torch==1.11.0
3 | seqeval==1.2.2
4 | scikit-learn==1.1.2
5 | allennlp==2.9.2
6 | transformers==4.17.0
7 |
--------------------------------------------------------------------------------
/src/ablation/arg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_args():
4 | parser = argparse.ArgumentParser()
5 |
6 | # Required parameters
7 | parser.add_argument("--output_dir", default=None, type=str, required=True,
8 | help="The output directory where the model checkpoints and predictions will be written.",
9 | )
10 | parser.add_argument("--train_file", default=None, type=str, required=True, help="The input training file.")
11 | parser.add_argument("--dev_file", default=None, type=str, required=True, help="The input evaluation file.")
12 | parser.add_argument("--test_file", default=None, type=str, required=True, help="The input testing file.")
13 |
14 | parser.add_argument("--model_type",
15 | default="longformer", type=str, required=False
16 | )
17 | parser.add_argument("--model_checkpoint",
18 | default="allenai/longformer-base-4096", type=str, required=False,
19 | help="Path to pretrained model or model identifier from huggingface.co/models",
20 | )
21 | parser.add_argument("--max_seq_length", default=4096, type=int, required=False)
22 | parser.add_argument("--matching_style", default="multi", type=str, required=True,
23 | help="how to match two event representations"
24 | )
25 |
26 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
27 | parser.add_argument("--do_test", action="store_true", help="Whether to run eval on the test set.")
28 | parser.add_argument("--do_predict", action="store_true", help="Whether to save predicted labels.")
29 | parser.add_argument("--do_analysis", action="store_true", help="Whether to do analysis on the test set.")
30 |
31 | # Other parameters
32 | parser.add_argument("--cache_dir", default=None, type=str,
33 |                         help="Where to store the pre-trained models downloaded from huggingface.co"
34 | )
35 | parser.add_argument("--topic_model", default='stm', type=str,
36 | choices=['stm', 'stm_bn', 'vmf']
37 | )
38 | parser.add_argument("--topic_dim", default=32, type=int)
39 | parser.add_argument("--topic_inter_map", default=64, type=int)
40 | parser.add_argument("--mention_encoder_type", default="bert", type=str)
41 | parser.add_argument("--mention_encoder_checkpoint",
42 | default="bert-large-cased", type=str,
43 | help="Path to pretrained model or model identifier from huggingface.co/models",
44 | )
45 | parser.add_argument("--include_mention_context", action="store_true")
46 | parser.add_argument("--max_mention_length", default=512, type=int)
47 | parser.add_argument("--add_contrastive_loss", action="store_true")
48 | parser.add_argument("--softmax_loss", default='ce', type=str,
49 | help="The loss function for softmax model.",
50 | choices=['lsr', 'focal', 'ce']
51 | )
52 |
53 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
54 | parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.")
55 | parser.add_argument("--batch_size", default=4, type=int)
56 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
57 |
58 | parser.add_argument("--adam_beta1", default=0.9, type=float,
59 |                         help="Beta1 for Adam optimizer."
60 | )
61 | parser.add_argument("--adam_beta2", default=0.98, type=float,
62 |                         help="Beta2 for Adam optimizer."
63 | )
64 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
65 | help="Epsilon for Adam optimizer."
66 | )
67 | parser.add_argument("--warmup_proportion", default=0.1, type=float,
68 |                         help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training."
69 | )
70 | parser.add_argument("--weight_decay", default=0.01, type=float,
71 | help="Weight decay if we apply some."
72 | )
73 | args = parser.parse_args()
74 | return args
--------------------------------------------------------------------------------
/src/ablation/chunk_global_encoder.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./ChunkBertEncoder_M-multi-cosine_results/
2 |
3 | python3 chunk_global_encoder.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --model_type=bert \
6 | --model_checkpoint=../../../PT_MODELS/bert-large-cased/ \
7 | --mention_encoder_type=bert \
8 | --mention_encoder_checkpoint=../../../PT_MODELS/bert-base-cased/ \
9 | --topic_model=vmf \
10 | --topic_dim=32 \
11 | --topic_inter_map=64 \
12 | --train_file=../../data/train_filtered.json \
13 | --dev_file=../../data/dev_filtered.json \
14 | --test_file=../../data/test_filtered.json \
15 | --max_seq_length=512 \
16 | --max_mention_length=256 \
17 | --learning_rate=1e-5 \
18 | --matching_style=multi_cosine \
19 | --softmax_loss=ce \
20 | --num_train_epochs=50 \
21 | --batch_size=1 \
22 | --do_train \
23 | --warmup_proportion=0. \
24 | --seed=42
--------------------------------------------------------------------------------
/src/ablation/without_global_encoder.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./NoGlobal_M-multi-cosine_results/
2 |
3 | python3 without_global_encoder.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --mention_encoder_type=bert \
6 | --mention_encoder_checkpoint=../../../PT_MODELS/bert-base-cased/ \
7 | --topic_model=vmf \
8 | --topic_dim=32 \
9 | --topic_inter_map=64 \
10 | --train_file=../../data/train_filtered.json \
11 | --dev_file=../../data/dev_filtered.json \
12 | --test_file=../../data/test_filtered.json \
13 | --max_mention_length=256 \
14 | --learning_rate=1e-5 \
15 | --matching_style=multi_cosine \
16 | --softmax_loss=ce \
17 | --num_train_epochs=50 \
18 | --batch_size=1 \
19 | --do_train \
20 | --warmup_proportion=0. \
21 | --seed=42
--------------------------------------------------------------------------------
/src/analysis/run_analysis.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import classification_report
2 | from collections import defaultdict
3 | import sys
4 | sys.path.append('../../')
5 | from src.analysis.utils import get_event_pair_set
6 |
7 | gold_coref_file = '../../data/test.json'
8 | pred_coref_file = 'MaskTopicBN_M-multi-cosine.json'
9 |
10 | def all_metrics(gold_coref_file, pred_coref_file):
11 | gold_coref_results, pred_coref_results = get_event_pair_set(gold_coref_file, pred_coref_file)
12 | all_event_pairs = [] # (gold_coref, pred_coref)
13 | for doc_id, gold_coref_result_dict in gold_coref_results.items():
14 | # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)}
15 | gold_unrecognized_event_pairs, gold_recognized_event_pairs = (
16 | gold_coref_result_dict['unrecognized_event_pairs'],
17 | gold_coref_result_dict['recognized_event_pairs']
18 | )
19 | pred_coref_result_dict = pred_coref_results[doc_id]
20 | # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)}
21 | pred_recognized_event_pairs, pred_wrong_event_pairs = (
22 | pred_coref_result_dict['recognized_event_pairs'],
23 | pred_coref_result_dict['wrong_event_pairs']
24 | )
25 | for pair_results in gold_unrecognized_event_pairs.values():
26 | all_event_pairs.append([str(pair_results[0]), '2'])
27 | for pair_id, pair_results in gold_recognized_event_pairs.items():
28 | all_event_pairs.append([str(pair_results[0]), str(pred_recognized_event_pairs[pair_id][0])])
29 | for pair_id, pair_results in pred_wrong_event_pairs.items():
30 | all_event_pairs.append(['0', str(pair_results[0])])
31 | y_true, y_pred = [res[0] for res in all_event_pairs], [res[1] for res in all_event_pairs]
32 | metrics = {'ALL': classification_report(y_true=y_true, y_pred=y_pred, output_dict=True)['1']}
33 | return metrics
34 |
35 | def different_distance_metrics(gold_coref_file, pred_coref_file, adj_distance=3):
36 | gold_coref_results, pred_coref_results = get_event_pair_set(gold_coref_file, pred_coref_file)
37 | same_event_pairs, adj_event_pairs, far_event_pairs = [], [], []
38 | for doc_id, gold_coref_result_dict in gold_coref_results.items():
39 | # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)}
40 | gold_unrecognized_event_pairs, gold_recognized_event_pairs = (
41 | gold_coref_result_dict['unrecognized_event_pairs'],
42 | gold_coref_result_dict['recognized_event_pairs']
43 | )
44 | pred_coref_result_dict = pred_coref_results[doc_id]
45 | # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)}
46 | pred_recognized_event_pairs, pred_wrong_event_pairs = (
47 | pred_coref_result_dict['recognized_event_pairs'],
48 | pred_coref_result_dict['wrong_event_pairs']
49 | )
50 | for pair_results in gold_unrecognized_event_pairs.values():
51 | sent_dist = pair_results[1]
52 | pair_coref = [str(pair_results[0]), '2']
53 | if sent_dist == 0: # same sentence
54 | same_event_pairs.append(pair_coref)
55 | elif sent_dist < adj_distance: # adjacent sentence
56 | adj_event_pairs.append(pair_coref)
57 | else: # far sentence
58 | far_event_pairs.append(pair_coref)
59 | for pair_id, pair_results in gold_recognized_event_pairs.items():
60 | sent_dist = pair_results[1]
61 | pair_coref = [str(pair_results[0]), str(pred_recognized_event_pairs[pair_id][0])]
62 | if sent_dist == 0: # same sentence
63 | same_event_pairs.append(pair_coref)
64 | elif sent_dist < adj_distance: # adjacent sentence
65 | adj_event_pairs.append(pair_coref)
66 | else: # far sentence
67 | far_event_pairs.append(pair_coref)
68 | for pair_id, pair_results in pred_wrong_event_pairs.items():
69 | sent_dist = pair_results[1]
70 | pair_coref = ['0', str(pair_results[0])]
71 | if sent_dist == 0: # same sentence
72 | same_event_pairs.append(pair_coref)
73 | elif sent_dist < adj_distance: # adjacent sentence
74 | adj_event_pairs.append(pair_coref)
75 | else: # far sentence
76 | far_event_pairs.append(pair_coref)
77 | metrics = {}
78 | y_true, y_pred = [res[0] for res in same_event_pairs], [res[1] for res in same_event_pairs]
79 | metrics['SAME'] = classification_report(y_true=y_true, y_pred=y_pred, output_dict=True)['1']
80 | y_true, y_pred = [res[0] for res in adj_event_pairs], [res[1] for res in adj_event_pairs]
81 | metrics['ADJ'] = classification_report(y_true=y_true, y_pred=y_pred, output_dict=True)['1']
82 | y_true, y_pred = [res[0] for res in far_event_pairs], [res[1] for res in far_event_pairs]
83 | metrics['FAR'] = classification_report(y_true=y_true, y_pred=y_pred, output_dict=True)['1']
84 | return metrics
85 |
86 | def main_link_metrics(gold_coref_file, pred_coref_file, main_link_length=5, mode='ge'):
87 | assert mode in ['g', 'ge', 'e', 'le', 'l']
88 | gold_coref_results, pred_coref_results = get_event_pair_set(gold_coref_file, pred_coref_file)
89 | main_link_event_pairs, singleton_event_pairs = [], defaultdict(list)
90 | for doc_id, gold_coref_result_dict in gold_coref_results.items():
91 | # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)}
92 | gold_unrecognized_event_pairs, gold_recognized_event_pairs = (
93 | gold_coref_result_dict['unrecognized_event_pairs'],
94 | gold_coref_result_dict['recognized_event_pairs']
95 | )
96 | pred_coref_result_dict = pred_coref_results[doc_id]
97 | # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)}
98 | pred_recognized_event_pairs, pred_wrong_event_pairs = (
99 | pred_coref_result_dict['recognized_event_pairs'],
100 | pred_coref_result_dict['wrong_event_pairs']
101 | )
102 |
103 | for pair_id, pair_results in gold_recognized_event_pairs.items():
104 | e_starts = pair_id.split('-')
105 | e_i_link_len, e_j_link_len = pair_results[2], pair_results[3]
106 | pair_coref = [str(pair_results[0]), str(pred_recognized_event_pairs[pair_id][0])]
107 | if e_i_link_len == 1:
108 | singleton_event_pairs[e_starts[0]].append(pair_coref[0] == pair_coref[1])
109 | if e_j_link_len == 1:
110 | singleton_event_pairs[e_starts[1]].append(pair_coref[0] == pair_coref[1])
111 | if mode == 'g':
112 | if e_i_link_len > main_link_length or e_j_link_len > main_link_length:
113 | main_link_event_pairs.append(pair_coref)
114 | elif mode == 'ge':
115 | if e_i_link_len >= main_link_length or e_j_link_len >= main_link_length:
116 | main_link_event_pairs.append(pair_coref)
117 | elif mode == 'e':
118 | if e_i_link_len == main_link_length or e_j_link_len == main_link_length:
119 | main_link_event_pairs.append(pair_coref)
120 | elif mode == 'le':
121 | if (e_i_link_len <= main_link_length and e_i_link_len > 1) or (e_j_link_len <= main_link_length and e_j_link_len > 1):
122 | main_link_event_pairs.append(pair_coref)
123 | elif mode == 'l':
124 | if (e_i_link_len < main_link_length and e_i_link_len > 1) or (e_j_link_len < main_link_length and e_j_link_len > 1):
125 | main_link_event_pairs.append(pair_coref)
126 | for pair_id, pair_results in gold_unrecognized_event_pairs.items():
127 | e_starts = pair_id.split('-')
128 | e_i_link_len, e_j_link_len = pair_results[2], pair_results[3]
129 | pair_coref = [str(pair_results[0]), '2']
130 | if e_i_link_len == 1 and e_starts[0] not in singleton_event_pairs:
131 | singleton_event_pairs[e_starts[0]].append(False)
132 | if e_j_link_len == 1 and e_starts[1] not in singleton_event_pairs:
133 | singleton_event_pairs[e_starts[1]].append(False)
134 | if mode == 'g':
135 | if e_i_link_len > main_link_length or e_j_link_len > main_link_length:
136 | main_link_event_pairs.append(pair_coref)
137 | elif mode == 'ge':
138 | if e_i_link_len >= main_link_length or e_j_link_len >= main_link_length:
139 | main_link_event_pairs.append(pair_coref)
140 | elif mode == 'e':
141 | if e_i_link_len == main_link_length or e_j_link_len == main_link_length:
142 | main_link_event_pairs.append(pair_coref)
143 | elif mode == 'le':
144 | if (e_i_link_len <= main_link_length and e_i_link_len > 1) or (e_j_link_len <= main_link_length and e_j_link_len > 1):
145 | main_link_event_pairs.append(pair_coref)
146 | elif mode == 'l':
147 | if (e_i_link_len < main_link_length and e_i_link_len > 1) or (e_j_link_len < main_link_length and e_j_link_len > 1):
148 | main_link_event_pairs.append(pair_coref)
149 |
150 | for pair_id, pair_results in pred_wrong_event_pairs.items():
151 | e_starts = pair_id.split('-')
152 | if e_starts[0] in singleton_event_pairs:
153 | singleton_event_pairs[e_starts[0]].append(pair_results[0] == 0)
154 | if e_starts[1] in singleton_event_pairs:
155 | singleton_event_pairs[e_starts[1]].append(pair_results[0] == 0)
156 |
157 | mode_str = {'g': '>', 'ge': '>=', 'e': '==', 'le': '<=', 'l': '<'}[mode]
158 | metrics = {}
159 | y_true, y_pred = [res[0] for res in main_link_event_pairs], [res[1] for res in main_link_event_pairs]
160 | metrics[f'Main Link ({mode_str}{main_link_length})'] = classification_report(y_true=y_true, y_pred=y_pred, output_dict=True)['1']
161 | wrong_num = sum([False in singleton_coref_correct for singleton_coref_correct in singleton_event_pairs.values()])
162 | print(wrong_num)
163 | print(len(singleton_event_pairs))
164 | metrics['Singleton Acc'] = (len(singleton_event_pairs) - wrong_num) / len(singleton_event_pairs) * 100
165 | return metrics
166 |
167 |
168 | # print(all_metrics(gold_coref_file, pred_coref_file))
169 | # print(different_distance_metrics(gold_coref_file, pred_coref_file))
170 | print(main_link_metrics(gold_coref_file, pred_coref_file, main_link_length=10, mode='ge'))
171 |
--------------------------------------------------------------------------------
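A note on the dummy label '2' used above: event pairs the system never scored are given prediction '2', so a gold-coreferent pair among them lowers recall of class '1' without being credited as a correct '0'. A minimal sketch with hypothetical labels:

    from sklearn.metrics import classification_report

    y_true = ['1', '0', '1']
    y_pred = ['1', '0', '2']  # the last gold pair was never recognized
    print(classification_report(y_true, y_pred, output_dict=True)['1'])
    # -> precision 1.0, recall 0.5, f1-score ~0.667 for the coreferent class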
/src/analysis/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import namedtuple, defaultdict
4 |
5 | Sentence = namedtuple("Sentence", ["start", "text"])
6 | kbp_sent_dic = defaultdict(list) # {filename: [Sentence]}
7 | with open(os.path.join('../../data/kbp_sent.txt'), 'rt', encoding='utf-8') as sents:
8 | for line in sents:
9 | doc_id, start, text = line.strip().split('\t')
10 | kbp_sent_dic[doc_id].append(Sentence(int(start), text))
11 |
12 | def get_event_sent_idx(e_start, e_end, sents):
13 | for sent_idx, sent in enumerate(sents):
14 | sent_end = sent.start + len(sent.text) - 1
15 | if e_start >= sent.start and e_end <= sent_end:
16 | return sent_idx
17 | return None
18 |
19 | def get_gold_corefs(gold_test_file):
20 |
21 | def _get_event_cluster_id_and_link_len(event_id, clusters):
22 | for cluster in clusters:
23 | if event_id in cluster['events']:
24 | return cluster['hopper_id'], len(cluster['events'])
25 | return None, None
26 |
27 | gold_dict = {}
28 | with open(gold_test_file, 'rt', encoding='utf-8') as f:
29 | for line in f:
30 | sample = json.loads(line.strip())
31 | clusters = sample['clusters']
32 | events = sample['events']
33 | event_pairs = {} # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)}
34 | for i in range(len(events) - 1):
35 | e_i_start = events[i]['start']
36 | e_i_cluster_id, e_i_link_len = _get_event_cluster_id_and_link_len(events[i]['event_id'], clusters)
37 | assert e_i_cluster_id is not None
38 | e_i_sent_idx = events[i]['sent_idx']
39 | for j in range(i + 1, len(events)):
40 | e_j_start = events[j]['start']
41 | e_j_cluster_id, e_j_link_len = _get_event_cluster_id_and_link_len(events[j]['event_id'], clusters)
42 | assert e_j_cluster_id is not None
43 | e_j_sent_idx = events[j]['sent_idx']
44 | event_pairs[f'{e_i_start}-{e_j_start}'] = [
45 | 1 if e_i_cluster_id == e_j_cluster_id else 0, abs(int(e_i_sent_idx) - int(e_j_sent_idx)), e_i_link_len, e_j_link_len
46 | ]
47 | gold_dict[sample['doc_id']] = event_pairs
48 | return gold_dict
49 |
50 | def get_pred_coref_results(pred_file_path):
51 | pred_dict = {}
52 | with open(pred_file_path, 'rt', encoding='utf-8') as f:
53 | for line in f:
54 | sample = json.loads(line.strip())
55 | sents = kbp_sent_dic[sample['doc_id']]
56 | events = sample['events']
57 | pred_labels = sample['pred_label']
58 | event_pairs = {} # {e_i_start-e_j_start: (coref, sent_dist, e_i_link_len, e_j_link_len)}
59 | event_pair_idx = -1
60 | for i in range(len(events) - 1):
61 | e_i_start = events[i]['start']
62 | e_i_sent_idx = get_event_sent_idx(events[i]['start'], events[i]['end'], sents)
63 | assert e_i_sent_idx is not None
64 | for j in range(i + 1, len(events)):
65 | event_pair_idx += 1
66 | e_j_start = events[j]['start']
67 | e_j_sent_idx = get_event_sent_idx(events[j]['start'], events[j]['end'], sents)
68 | assert e_j_sent_idx is not None
69 | event_pairs[f'{e_i_start}-{e_j_start}'] = [pred_labels[event_pair_idx], abs(int(e_i_sent_idx) - int(e_j_sent_idx)), 0, 0]
70 | pred_dict[sample['doc_id']] = event_pairs
71 | return pred_dict
72 |
73 | def get_event_pair_set(gold_coref_file, pred_coref_file):
74 |
75 | gold_coref_results = get_gold_corefs(gold_coref_file)
76 | pred_coref_results = get_pred_coref_results(pred_coref_file)
77 |
78 | new_gold_coref_results = {}
79 | for doc_id, event_pairs in gold_coref_results.items():
80 | pred_event_pairs = pred_coref_results[doc_id]
81 | unrecognized_event_pairs = {}
82 | recognized_event_pairs = {}
83 | for pair_id, results in event_pairs.items():
84 | if pair_id in pred_event_pairs:
85 | recognized_event_pairs[pair_id] = results
86 | else:
87 | unrecognized_event_pairs[pair_id] = results
88 | new_gold_coref_results[doc_id] = {
89 | 'unrecognized_event_pairs': unrecognized_event_pairs,
90 | 'recognized_event_pairs': recognized_event_pairs
91 | }
92 | new_pred_coref_results = {}
93 | for doc_id, event_pairs in pred_coref_results.items():
94 | gold_event_pairs = gold_coref_results[doc_id]
95 | recognized_event_pairs = {}
96 | wrong_event_pairs = {}
97 | for pair_id, results in event_pairs.items():
98 | if pair_id in gold_event_pairs:
99 | recognized_event_pairs[pair_id] = results
100 | else:
101 | wrong_event_pairs[pair_id] = results
102 | new_pred_coref_results[doc_id] = {
103 | 'recognized_event_pairs': recognized_event_pairs,
104 | 'wrong_event_pairs': wrong_event_pairs
105 | }
106 |
107 | return new_gold_coref_results, new_pred_coref_results
108 |
--------------------------------------------------------------------------------
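A quick sketch of get_event_sent_idx on two hypothetical sentences (offsets are document-level, as in kbp_sent.txt); run it from src/analysis/, since importing the module also reads ../../data/kbp_sent.txt:

    from utils import Sentence, get_event_sent_idx

    sents = [Sentence(0, 'He left.'), Sentence(9, 'She stayed.')]
    assert get_event_sent_idx(13, 18, sents) == 1  # 'stayed' (document chars 13-18) lies in sentence index 1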
/src/clustering/arg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_args():
4 | parser = argparse.ArgumentParser()
5 |
6 | # Required parameters
7 | parser.add_argument("--output_dir", default=None, type=str, required=True,
8 | help="The output directory where the conll files and evaluate results will be written.",
9 | )
10 | parser.add_argument("--test_golden_filepath", default=None, type=str, required=True,
11 | help="golden test set file path.",
12 | )
13 | parser.add_argument("--test_pred_filepath", default=None, type=str, required=True,
14 | help="predicted coref file path.",
15 | )
16 | parser.add_argument("--golden_conll_filename", default=None, type=str, required=True)
17 | parser.add_argument("--pred_conll_filename", default=None, type=str, required=True)
18 |
19 | # Other parameters
20 |     parser.add_argument("--do_rescore", action="store_true", help="Whether to rescore pairwise coreference predictions before clustering.")
21 | parser.add_argument("--rescore_reward", default=0.8, type=float, required=False)
22 | parser.add_argument("--rescore_penalty", default=0.8, type=float, required=False)
23 | parser.add_argument("--do_evaluate", action="store_true", help="Whether to evaluate conll files.")
24 |
25 | args = parser.parse_args()
26 | return args
27 |
--------------------------------------------------------------------------------
/src/clustering/cluster.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict, defaultdict
2 |
3 | def clustering_greedy(events, pred_labels:list):
4 | '''
5 |     Merge any two event chains that share at least one
6 |     coreferent event pair, and repeat until no merge is possible.
7 | '''
8 | def need_merge(set_1, set_2, coref_event_pair_set):
9 | for e1 in set_1:
10 | for e2 in set_2:
11 | if f'{e1}-{e2}' in coref_event_pair_set:
12 | return True
13 | return False
14 |
15 | def find_merge_position(cluster_list, coref_event_pairs):
16 | for i in range(len(cluster_list) - 1):
17 | for j in range(i + 1, len(cluster_list)):
18 | if need_merge(cluster_list[i], cluster_list[j], coref_event_pairs):
19 | return i, j
20 | return -1, -1
21 |
22 | if len(events) > 1:
23 | assert len(pred_labels) == len(events) * (len(events) - 1) / 2
24 | event_pairs = [
25 | str(events[i]['start']) + '-' + str(events[j]['start'])
26 | for i in range(len(events) - 1) for j in range(i + 1, len(events))
27 | ]
28 | coref_event_pairs = [event_pair for event_pair, pred in zip(event_pairs, pred_labels) if pred == 1]
29 | cluster_list = []
30 | for event in events: # init each link as an event
31 | cluster_list.append(set([event['start']]))
32 | while True:
33 | i, j = find_merge_position(cluster_list, coref_event_pairs)
34 | if i == -1: # no cluster can be merged
35 | break
36 | cluster_list[i] |= cluster_list[j]
37 | del cluster_list[j]
38 | return cluster_list
39 |
40 | def clustering_rescore(events, pred_labels:list, reward=0.8, penalty=0.8):
41 | event_pairs = [
42 | str(events[i]['start']) + '-' + str(events[j]['start'])
43 | for i in range(len(events) - 1) for j in range(i + 1, len(events))
44 | ]
45 | coref_event_pairs = [event_pair for event_pair, pred in zip(event_pairs, pred_labels) if pred == 1]
46 | coref = OrderedDict([(event_pair, 1 if pred == 1 else -1) for event_pair, pred in zip(event_pairs, pred_labels)])
47 | for i in range(len(events) - 1):
48 | for j in range(i + 1, len(events)):
49 | for k in range(len(events)):
50 | if k == i or k == j:
51 | continue
52 | event_i, event_j, event_k = events[i]['start'], events[j]['start'], events[k]['start']
53 | coref_i_k = (f'{event_k}-{event_i}' if k < i else f'{event_i}-{event_k}') in coref_event_pairs
54 | coref_j_k = (f'{event_k}-{event_j}' if k < j else f'{event_j}-{event_k}') in coref_event_pairs
55 | if coref_i_k and coref_j_k:
56 | coref[f'{event_i}-{event_j}'] += reward
57 | elif coref_i_k != coref_j_k:
58 | coref[f'{event_i}-{event_j}'] -= penalty
59 | coref = OrderedDict([(event_pair, score) for event_pair, score in coref.items() if score > 0])
60 | sorted_coref = sorted(coref.items(), key=lambda x:x[1], reverse=True)
61 | cluster_id = 0
62 | events_cluster_ids = {str(event['start']):-1 for event in events} # {event:cluster_id}
63 | for event_pair, _ in sorted_coref:
64 | e_i, e_j = event_pair.split('-')
65 | if events_cluster_ids[e_i] == events_cluster_ids[e_j] == -1:
66 | events_cluster_ids[e_i] = events_cluster_ids[e_j] = cluster_id
67 | cluster_id += 1
68 | elif events_cluster_ids[e_i] == -1:
69 | events_cluster_ids[e_i] = events_cluster_ids[e_j]
70 | elif events_cluster_ids[e_j] == -1:
71 |             events_cluster_ids[e_j] = events_cluster_ids[e_i]
72 | for event, c_id in events_cluster_ids.items():
73 | if c_id == -1:
74 | events_cluster_ids[event] = cluster_id
75 | cluster_id += 1
76 | cluster_list = defaultdict(set)
77 | for event, c_id in events_cluster_ids.items():
78 | cluster_list[c_id].add(event)
79 | return [v for v in cluster_list.values()]
80 |
81 | def clustering(events, pred_labels:list, mode='rescore', rescore_reward=0.8, rescore_penalty=0.8):
82 | assert mode in ['greedy', 'rescore']
83 | if mode == 'rescore':
84 | return clustering_rescore(events, pred_labels, rescore_reward, rescore_penalty)
85 | elif mode == 'greedy':
86 | return clustering_greedy(events, pred_labels)
87 |
--------------------------------------------------------------------------------
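A toy run of the greedy variant (hypothetical events and labels; for three events the flattened pair order is (e0,e1), (e0,e2), (e1,e2)):

    from cluster import clustering

    events = [{'start': 0}, {'start': 5}, {'start': 9}]
    print(clustering(events, [1, 0, 1], mode='greedy'))
    # -> [{0, 5, 9}]: e0-e1 and e1-e2 are coreferent, so all three chains merge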
/src/clustering/run_cluster.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import logging
4 | from tqdm.auto import tqdm
5 | import json
6 | import subprocess
7 | import re
8 | import sys
9 | sys.path.append('../../')
10 | from src.clustering.arg import parse_args
11 | from src.clustering.utils import create_golden_conll_file, get_pred_coref_results, create_pred_conll_file
12 | from src.clustering.cluster import clustering
13 |
14 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
15 | datefmt='%Y/%m/%d %H:%M:%S',
16 | level=logging.INFO)
17 | logger = logging.getLogger("Cluster")
18 |
19 | COREF_RESULTS_REGEX = re.compile(r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*", re.DOTALL)
20 | BLANC_RESULTS_REGEX = re.compile(r".*BLANC: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*", re.DOTALL)
21 |
22 | def official_conll_eval(gold_path, predicted_path, metric, official_stdout=True):
23 | assert metric in ["muc", "bcub", "ceafe", "blanc"]
24 | cmd = ["../../reference-coreference-scorers/scorer.pl", metric, gold_path, predicted_path, "none"]
25 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE)
26 | stdout, stderr = process.communicate()
27 | process.wait()
28 |
29 | stdout = stdout.decode("utf-8")
30 | if stderr is not None:
31 | logger.error(stderr)
32 |
33 | if official_stdout:
34 | logger.info("Official result for {}".format(metric))
35 | logger.info(stdout)
36 |
37 | coref_results_match = re.match(
38 | BLANC_RESULTS_REGEX if metric == 'blanc' else COREF_RESULTS_REGEX,
39 | stdout
40 | )
41 | recall = float(coref_results_match.group(1))
42 | precision = float(coref_results_match.group(2))
43 | f1 = float(coref_results_match.group(3))
44 | return {"r": recall, "p": precision, "f": f1}
45 |
46 | if __name__ == '__main__':
47 | args = parse_args()
48 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
49 | raise ValueError(
50 | f'Output directory ({args.output_dir}) already exists and is not empty.')
51 | if not os.path.exists(args.output_dir):
52 | os.mkdir(args.output_dir)
53 | golden_conll_path = os.path.join(args.output_dir, args.golden_conll_filename)
54 | pred_conll_path = os.path.join(args.output_dir, args.pred_conll_filename)
55 |
56 | logger.info(f'creating golden conll file in {args.output_dir} ...')
57 | create_golden_conll_file(args.test_golden_filepath, golden_conll_path)
58 | # clustering
59 | # {doc_id: {'events': event_list, 'pred_labels': pred_coref_labels}}
60 | pred_coref_results = get_pred_coref_results(args.test_pred_filepath)
61 | cluster_dict = {} # {doc_id: [cluster_set_1, cluster_set_2, ...]}
62 | logger.info('clustering ...')
63 | for doc_id, pred_result in tqdm(pred_coref_results.items()):
64 | cluster_list = clustering(
65 | pred_result['events'],
66 | pred_result['pred_labels'],
67 | mode='rescore' if args.do_rescore else 'greedy',
68 | rescore_reward=args.rescore_reward,
69 | rescore_penalty=args.rescore_penalty
70 | )
71 | cluster_dict[doc_id] = cluster_list
72 | logger.info(f'saving predicted clusters in {args.output_dir} ...')
73 | create_pred_conll_file(cluster_dict, golden_conll_path, pred_conll_path)
74 | with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f:
75 | f.write(str(args))
76 | # evaluate on the conll files
77 | if args.do_evaluate:
78 | results = {
79 | m: official_conll_eval(golden_conll_path, pred_conll_path, m, official_stdout=True)
80 | for m in ("muc", "bcub", "ceafe", "blanc")
81 | }
82 | results['avg_f1'] = sum([scores['f'] for scores in results.values()]) / len(results)
83 | logger.info(results)
84 | with open(os.path.join(args.output_dir, 'evaluate_results.json'), 'wt', encoding='utf-8') as f:
85 | f.write(json.dumps(results) + '\n')
86 |     shutil.rmtree(args.output_dir)  # NOTE: removes the output directory (treated as temporary), including the files just written
87 |
--------------------------------------------------------------------------------
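For reference, official_conll_eval above shells out to the CoNLL reference scorer; the equivalent standalone call, as a sketch (assumes a local checkout of reference-coreference-scorers and the conll filenames used in run_cluster.sh):

    import subprocess

    subprocess.run(
        ["../../reference-coreference-scorers/scorer.pl", "muc",
         "gold_test.conll", "pred_test.conll", "none"],
        check=True,
    )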
/src/clustering/run_cluster.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./TEMP/
2 |
3 | python3 run_cluster.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --test_golden_filepath=../../data/test.json \
6 | --test_pred_filepath=../../data/XXX_weights.bin_test_pred_corefs.json \
7 | --golden_conll_filename=gold_test.conll \
8 | --pred_conll_filename=pred_test.conll \
9 | --do_evaluate \
10 | # --do_rescore \
11 | # --rescore_reward=0.5 \
12 | # --rescore_penalty=0.5
--------------------------------------------------------------------------------
/src/clustering/utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import json
3 |
4 | def get_pred_coref_results(pred_file_path):
5 | pred_results = {} # {doc_id: {'events': event_list, 'pred_labels': pred_coref_labels}}
6 | with open(pred_file_path, 'rt', encoding='utf-8') as f:
7 | for line in f:
8 | sample = json.loads(line.strip())
9 | pred_results[sample['doc_id']] = {
10 | 'events': sample['events'],
11 | 'pred_labels': sample['pred_label']
12 | }
13 | return pred_results
14 |
15 | def create_golden_conll_file(test_file_path, conll_file_path):
16 |
17 | def get_event_cluster_idx(event_id:str, clusters):
18 | for idx, cluster in enumerate(clusters):
19 | if event_id in cluster['events']:
20 | return idx
21 |         print(f'ERROR: event {event_id} not found in any cluster!')
22 | return None
23 |
24 | with open(test_file_path, 'rt', encoding='utf-8') as f_in, \
25 | open(conll_file_path, 'wt', encoding='utf-8') as f_out:
26 | for line in f_in:
27 | sample = json.loads(line.strip())
28 | doc_id = sample['doc_id']
29 | f_out.write(f'#begin document ({doc_id});\n')
30 | clusters = sample['clusters']
31 | for event in sample['events']:
32 | cluster_idx = get_event_cluster_idx(event['event_id'], clusters)
33 | start = event['start']
34 | f_out.write(f'{doc_id}\t{start}\txxx\t({cluster_idx})\n')
35 | f_out.write('#end document\n')
36 |
37 | def create_pred_conll_file(cluster_dict:dict, golden_conll_filepath:str, conll_filepath:str, no_repeat=True):
38 | '''
39 | # Args:
40 | - cluster_dict: {doc_id: [cluster_set_1, cluster_set_2, ...]}
41 | '''
42 | new_cluster_dict = {} # {doc_id: {event: cluster_idx}}
43 | for doc_id, cluster_list in cluster_dict.items():
44 | event_cluster_idx = {} # {event: cluster_idx}
45 | for c_idx, cluster in enumerate(cluster_list):
46 | for event in cluster:
47 | event_cluster_idx[str(event)] = c_idx
48 | new_cluster_dict[doc_id] = event_cluster_idx
49 | golden_file_dic = collections.OrderedDict() # {doc_id: [event_1, event_2, ...]}
50 | with open(golden_conll_filepath, 'rt', encoding='utf-8') as f_in:
51 | for line in f_in:
52 | if line.startswith('#begin'):
53 | doc_id = line.replace('#begin document (', '').replace(');', '').strip()
54 | golden_file_dic[doc_id] = []
55 | elif line.startswith('#end document'):
56 | continue
57 | else:
58 | _, event, _, _ = line.strip().split('\t')
59 | golden_file_dic[doc_id].append(event)
60 | with open(conll_filepath, 'wt', encoding='utf-8') as f_out:
61 | for doc_id, event_list in golden_file_dic.items():
62 | event_cluster_idx = new_cluster_dict[doc_id]
63 | f_out.write('#begin document (' + doc_id + ');\n')
64 | if no_repeat:
65 | finish_events = set()
66 | for event in event_list:
67 | if event in event_cluster_idx and event not in finish_events:
68 | cluster_idx = event_cluster_idx[event]
69 | f_out.write(f'{doc_id}\t{event}\txxx\t({cluster_idx})\n')
70 | else:
71 | f_out.write(f'{doc_id}\tnull\tnull\tnull\n')
72 | finish_events.add(event)
73 | else:
74 | for event in event_list:
75 | if event in event_cluster_idx:
76 | cluster_idx = event_cluster_idx[event]
77 | f_out.write(f'{doc_id}\t{event}\txxx\t({cluster_idx})\n')
78 | else:
79 | f_out.write(f'{doc_id}\tnull\tnull\tnull\n')
80 | for event, cluster_idx in event_cluster_idx.items():
81 | if event in event_list:
82 | continue
83 | f_out.write(f'{doc_id}\t{event}\txxx\t({cluster_idx})\n')
84 | f_out.write('#end document\n')
85 |
--------------------------------------------------------------------------------
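For orientation, the minimal CoNLL-style format both helpers emit looks like this for a hypothetical document (fields are tab-separated; the second field is the trigger's start offset and the parenthesized number is the cluster index):

    #begin document (doc_0);
    doc_0	10	xxx	(0)
    doc_0	42	xxx	(0)
    doc_0	77	xxx	(1)
    #end document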
/src/global_event_coref/analysis.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | WRONG_TYPE = {
4 | 0: 'recognize_non-coref_as_coref',
5 | 1: 'recognize_coref_as_non-coref'
6 | }
7 |
8 | def get_pretty_event(sentences, sent_idx, sent_start, trigger, context=1):
9 | before = ' '.join([sent['text'] for sent in sentences[max(0, sent_idx-context):sent_idx]]).strip()
10 | after = ' '.join([sent['text'] for sent in sentences[sent_idx+1:min(len(sentences), sent_idx+context+1)]]).strip()
11 | event_mention = sentences[sent_idx]['text']
12 | sent = event_mention[:sent_start] + '#####' + trigger + '#####' + event_mention[sent_start + len(trigger):]
13 | return before + ' ' + sent + ' ' + after
14 |
15 | def find_event_by_start(events, offset):
16 | for event in events:
17 | if event['start'] == offset:
18 | return event
19 | return None
20 |
21 | def get_coref_answer(clusters, e1_id, e2_id):
22 | for cluster in clusters:
23 | events = cluster['events']
24 | if e1_id in events and e2_id in events:
25 | return 1
26 | elif e1_id in events or e2_id in events:
27 | return 0
28 | return 0
29 |
30 | def get_wrong_samples(doc_id, new_events, predictions, source_events, clusters, sentences, pred_event_filepath):
31 | wrong_1_list, wrong_2_list = [], []
32 |
33 | pred_event_dict = {}
34 | with open(pred_event_filepath, 'rt' , encoding='utf-8') as f_in:
35 | for line in f_in.readlines():
36 | sample = json.loads(line.strip())
37 | pred_event_dict[sample['doc_id']] = [event['start'] for event in sample['pred_label']]
38 |
39 | idx = 0
40 | true_labels = []
41 | for i in range(len(new_events) - 1):
42 | for j in range(i + 1, len(new_events)):
43 | e1_start, e2_start = new_events[i][0], new_events[j][0]
44 | if e1_start not in pred_event_dict[doc_id] or e2_start not in pred_event_dict[doc_id]:
45 | idx += 1
46 | continue
47 | e1 = find_event_by_start(source_events, e1_start)
48 | e2 = find_event_by_start(source_events, e2_start)
49 | pred_coref = predictions[idx]
50 | idx += 1
51 | true_coref = get_coref_answer(clusters, e1['event_id'], e2['event_id'])
52 | true_labels.append(true_coref)
53 | if pred_coref == true_coref:
54 | continue
55 | pretty_e1 = get_pretty_event(sentences, e1['sent_idx'], e1['sent_start'], e1['trigger'])
56 | pretty_e2 = get_pretty_event(sentences, e2['sent_idx'], e2['sent_start'], e2['trigger'])
57 | if pred_coref == 1:
58 | wrong_1_list.append({
59 | 'doc_id': doc_id,
60 | 'e1_start': e1_start,
61 | 'e2_start': e2_start,
62 | 'e1_info': pretty_e1,
63 | 'e2_info': pretty_e2,
64 | 'wrong_type': 0
65 | })
66 | else:
67 | wrong_2_list.append({
68 | 'doc_id': doc_id,
69 | 'e1_start': e1_start,
70 | 'e2_start': e2_start,
71 | 'e1_info': pretty_e1,
72 | 'e2_info': pretty_e2,
73 | 'wrong_type': 1
74 | })
75 | return wrong_1_list, wrong_2_list, true_labels
--------------------------------------------------------------------------------
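A small sketch of get_pretty_event on hypothetical sentences: it wraps the trigger in ##### markers and pads the mention with context neighboring sentences on each side (one by default):

    from analysis import get_pretty_event

    sentences = [{'text': 'A bomb went off.'},
                 {'text': 'The blast hurt two.'},
                 {'text': 'Police came.'}]
    print(get_pretty_event(sentences, sent_idx=1, sent_start=4, trigger='blast'))
    # -> A bomb went off. The #####blast##### hurt two. Police came.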
/src/global_event_coref/arg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_args():
4 | parser = argparse.ArgumentParser()
5 |
6 | # Required parameters
7 | parser.add_argument("--output_dir", default=None, type=str, required=True,
8 | help="The output directory where the model checkpoints and predictions will be written.",
9 | )
10 | parser.add_argument("--train_file", default=None, type=str, required=True, help="The input training file.")
11 | parser.add_argument("--dev_file", default=None, type=str, required=True, help="The input evaluation file.")
12 | parser.add_argument("--test_file", default=None, type=str, required=True, help="The input testing file.")
13 |
14 | parser.add_argument("--model_type",
15 | default="longformer", type=str, required=True
16 | )
17 | parser.add_argument("--model_checkpoint",
18 | default="allenai/longformer-base-4096", type=str, required=True,
19 | help="Path to pretrained model or model identifier from huggingface.co/models",
20 | )
21 | parser.add_argument("--max_seq_length", default=4096, type=int, required=True)
22 | parser.add_argument("--matching_style", default="multi", type=str, required=True,
23 | help="how to match two event representations"
24 | )
25 |
26 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
27 | parser.add_argument("--do_test", action="store_true", help="Whether to run eval on the test set.")
28 | parser.add_argument("--do_predict", action="store_true", help="Whether to save predicted labels.")
29 | parser.add_argument("--do_analysis", action="store_true", help="Whether to do analysis on the test set.")
30 |
31 | # Other parameters
32 | parser.add_argument("--cache_dir", default=None, type=str,
33 |                         help="Where to store the pre-trained models downloaded from huggingface.co"
34 | )
35 | parser.add_argument("--topic_model", default='stm', type=str,
36 | choices=['stm', 'stm_bn', 'vmf']
37 | )
38 | parser.add_argument("--topic_dim", default=32, type=int)
39 | parser.add_argument("--topic_inter_map", default=64, type=int)
40 | parser.add_argument("--mention_encoder_type", default="bert", type=str)
41 | parser.add_argument("--mention_encoder_checkpoint",
42 | default="bert-large-cased", type=str,
43 | help="Path to pretrained model or model identifier from huggingface.co/models",
44 | )
45 | parser.add_argument("--include_mention_context", action="store_true")
46 | parser.add_argument("--max_mention_length", default=512, type=int)
47 | parser.add_argument("--add_contrastive_loss", action="store_true")
48 | parser.add_argument("--softmax_loss", default='ce', type=str,
49 | help="The loss function for softmax model.",
50 | choices=['lsr', 'focal', 'ce']
51 | )
52 |
53 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
54 | parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.")
55 | parser.add_argument("--batch_size", default=4, type=int)
56 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
57 |
58 | parser.add_argument("--adam_beta1", default=0.9, type=float,
59 |                         help="Beta1 for Adam optimizer."
60 | )
61 | parser.add_argument("--adam_beta2", default=0.98, type=float,
62 |                         help="Beta2 for Adam optimizer."
63 | )
64 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
65 | help="Epsilon for Adam optimizer."
66 | )
67 | parser.add_argument("--warmup_proportion", default=0.1, type=float,
68 |                         help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training."
69 | )
70 | parser.add_argument("--weight_decay", default=0.01, type=float,
71 | help="Weight decay if we apply some."
72 | )
73 | args = parser.parse_args()
74 | return args
--------------------------------------------------------------------------------
/src/global_event_coref/run_global_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import torch
4 | import json
5 | from tqdm.auto import tqdm
6 | from transformers import AutoConfig, AutoTokenizer
7 | from transformers import AdamW, get_scheduler
8 | import numpy as np
9 | from sklearn.metrics import classification_report
10 | import sys
11 | sys.path.append('../../')
12 | from src.tools import seed_everything, NpEncoder
13 | from src.global_event_coref.arg import parse_args
14 | from src.global_event_coref.data import KBPCoref, get_dataLoader
15 | from src.global_event_coref.modeling import LongformerSoftmaxForEC
16 | from src.global_event_coref.analysis import get_wrong_samples, WRONG_TYPE
17 |
18 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
19 | datefmt='%Y/%m/%d %H:%M:%S',
20 | level=logging.INFO)
21 | logger = logging.getLogger("Model")
22 |
23 | def to_device(args, batch_data):
24 | new_batch_data = {}
25 | for k, v in batch_data.items():
26 | if k in ['batch_events', 'batch_event_cluster_ids']:
27 | new_batch_data[k] = v
28 | elif k == 'batch_inputs':
29 | new_batch_data[k] = {
30 | k_: v_.to(args.device) for k_, v_ in v.items()
31 | }
32 | else:
33 | raise ValueError(f'Unknown batch data key: {k}')
34 | return new_batch_data
35 |
36 | def train_loop(args, dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
37 | progress_bar = tqdm(range(len(dataloader)))
38 | progress_bar.set_description(f'loss: {0:>7f}')
39 | finish_step_num = epoch * len(dataloader)
40 |
41 | model.train()
42 | for step, batch_data in enumerate(dataloader, start=1):
43 | batch_data = to_device(args, batch_data)
44 | outputs = model(**batch_data)
45 | loss = outputs[0]
46 |
47 | if loss:
48 | optimizer.zero_grad()
49 | loss.backward()
50 | optimizer.step()
51 | lr_scheduler.step()
52 |
53 | total_loss += loss.item() if loss else 0.
54 | progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}')
55 | progress_bar.update(1)
56 | return total_loss
57 |
58 | def test_loop(args, dataloader, model):
59 | true_labels, true_predictions = [], []
60 | model.eval()
61 | with torch.no_grad():
62 | for batch_data in tqdm(dataloader):
63 | batch_data = to_device(args, batch_data)
64 | outputs = model(**batch_data)
65 | _, logits, masks, labels = outputs
66 |
67 | predictions = logits.argmax(dim=-1).cpu().numpy() # [batch, event_pair_num]
68 | y = labels.cpu().numpy()
69 | lens = np.sum(masks.cpu().numpy(), axis=-1)
70 | true_labels += [
71 | int(l) for label, seq_len in zip(y, lens) for idx, l in enumerate(label) if idx < seq_len
72 | ]
73 | true_predictions += [
74 | int(p) for pred, seq_len in zip(predictions, lens) for idx, p in enumerate(pred) if idx < seq_len
75 | ]
76 | return classification_report(true_labels, true_predictions, output_dict=True)
77 |
78 | def train(args, train_dataset, dev_dataset, model, tokenizer):
79 | """ Train the model """
80 | train_dataloader = get_dataLoader(args, train_dataset, tokenizer, shuffle=True)
81 | dev_dataloader = get_dataLoader(args, dev_dataset, tokenizer, shuffle=False)
82 | t_total = len(train_dataloader) * args.num_train_epochs
83 | # Prepare optimizer and schedule (linear warmup and decay)
84 | no_decay = ["bias", "LayerNorm.weight"]
85 | optimizer_grouped_parameters = [
86 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
87 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
88 | ]
89 | args.warmup_steps = int(t_total * args.warmup_proportion)
90 | optimizer = AdamW(
91 | optimizer_grouped_parameters,
92 | lr=args.learning_rate,
93 | betas=(args.adam_beta1, args.adam_beta2),
94 | eps=args.adam_epsilon
95 | )
96 | lr_scheduler = get_scheduler(
97 | 'linear',
98 | optimizer,
99 | num_warmup_steps=args.warmup_steps,
100 | num_training_steps=t_total
101 | )
102 | # Train!
103 | logger.info("***** Running training *****")
104 | logger.info(f"Num examples - {len(train_dataset)}")
105 | logger.info(f"Num Epochs - {args.num_train_epochs}")
106 | logger.info(f"Total optimization steps - {t_total}")
107 | with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f:
108 | f.write(str(args))
109 |
110 | total_loss = 0.
111 | best_f1 = 0.
112 | for epoch in range(args.num_train_epochs):
113 | print(f"Epoch {epoch+1}/{args.num_train_epochs}\n-------------------------------")
114 | total_loss = train_loop(args, train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss)
115 | metrics = test_loop(args, dev_dataloader, model)
116 | dev_p, dev_r, dev_f1 = metrics['1']['precision'], metrics['1']['recall'], metrics['1']['f1-score']
117 | logger.info(f'Dev: P - {(100*dev_p):0.4f} R - {(100*dev_r):0.4f} F1 - {(100*dev_f1):0.4f}')
118 | if dev_f1 > best_f1:
119 | best_f1 = dev_f1
120 | logger.info(f'saving new weights to {args.output_dir}...\n')
121 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin'
122 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight))
123 | elif 100 * dev_p > 69 and 100 * dev_r > 69:
124 | logger.info(f'saving new weights to {args.output_dir}...\n')
125 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin'
126 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight))
127 | with open(os.path.join(args.output_dir, 'dev_metrics.txt'), 'at') as f:
128 | f.write(f'epoch_{epoch+1}\n' + json.dumps(metrics, cls=NpEncoder) + '\n\n')
129 | logger.info("Done!")
130 |
131 | def predict(args, document:str, events:list, model, tokenizer):
132 | '''
133 | # Args:
134 | - events: [
135 | [e_char_start, e_char_end], ...
136 |     ], where document[e_char_start:e_char_end + 1] == trigger
137 | '''
138 | inputs = tokenizer(
139 | document,
140 | max_length=args.max_seq_length,
141 | truncation=True,
142 | return_tensors="pt"
143 | )
144 | filtered_events = []
145 | new_events = []
146 | for event in events:
147 | char_start, char_end = event
148 | token_start = inputs.char_to_token(char_start)
149 | if not token_start:
150 | token_start = inputs.char_to_token(char_start + 1)
151 | token_end = inputs.char_to_token(char_end)
152 | if not token_start or not token_end:
153 | continue
154 | filtered_events.append([token_start, token_end])
155 | new_events.append(event)
156 | if not new_events:
157 | return [], [], []
158 | inputs = {
159 | 'batch_inputs': inputs,
160 | 'batch_events': [filtered_events]
161 | }
162 | inputs = to_device(args, inputs)
163 | with torch.no_grad():
164 | outputs = model(**inputs)
165 | logits = outputs[1]
166 | predictions = logits.argmax(dim=-1)[0].cpu().numpy().tolist()
167 | probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy().tolist()
168 | probabilities = [probabilities[idx][pred] for idx, pred in enumerate(predictions)]
169 | if len(new_events) > 1:
170 | assert len(predictions) == len(new_events) * (len(new_events) - 1) / 2
171 | return new_events, predictions, probabilities
172 |
173 | def test(args, test_dataset, model, tokenizer, save_weights:list):
174 | test_dataloader = get_dataLoader(args, test_dataset, tokenizer, batch_size=1, shuffle=False)
175 | logger.info('***** Running testing *****')
176 | for save_weight in save_weights:
177 | logger.info(f'loading weights from {save_weight}...')
178 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight)))
179 | metrics = test_loop(args, test_dataloader, model)
180 | with open(os.path.join(args.output_dir, 'test_metrics.txt'), 'at') as f:
181 | f.write(f'{save_weight}\n{json.dumps(metrics, cls=NpEncoder)}\n\n')
182 |
183 | if __name__ == '__main__':
184 | args = parse_args()
185 | if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir):
186 | raise ValueError(f'Output directory ({args.output_dir}) already exists and is not empty.')
187 | if not os.path.exists(args.output_dir):
188 | os.mkdir(args.output_dir)
189 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
190 | args.n_gpu = torch.cuda.device_count()
191 | logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}')
192 | # Set seed
193 | seed_everything(args.seed)
194 | # Load pretrained model and tokenizer
195 | logger.info(f'using model {"with" if args.add_contrastive_loss else "without"} Contrastive loss')
196 | logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...')
197 | config = AutoConfig.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir, )
198 | tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir)
199 | args.num_labels = 2
200 | model = LongformerSoftmaxForEC.from_pretrained(
201 | args.model_checkpoint,
202 | config=config,
203 | cache_dir=args.cache_dir,
204 | args=args
205 | ).to(args.device)
206 | # Training
207 | if args.do_train:
208 | train_dataset = KBPCoref(args.train_file)
209 | dev_dataset = KBPCoref(args.dev_file)
210 | train(args, train_dataset, dev_dataset, model, tokenizer)
211 | # Testing
212 | save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')]
213 | if args.do_test:
214 | test_dataset = KBPCoref(args.test_file)
215 | test(args, test_dataset, model, tokenizer, save_weights)
216 | # Predicting
217 | if args.do_predict:
218 | pred_event_file = 'epoch_3_dev_f1_57.9994_weights.bin_test_pred_events.json'
219 | # pred_event_file = 'test_filtered.json'
220 |
221 | for best_save_weight in save_weights:
222 | logger.info(f'loading weights from {best_save_weight}...')
223 | model.load_state_dict(torch.load(os.path.join(args.output_dir, best_save_weight)))
224 | logger.info(f'predicting coref labels of {best_save_weight}...')
225 | results = []
226 | model.eval()
227 | with open(os.path.join(args.output_dir, pred_event_file), 'rt' , encoding='utf-8') as f_in:
228 | for line in tqdm(f_in.readlines()):
229 | sample = json.loads(line.strip())
230 | events_from_file = sample['events'] if pred_event_file == 'test_filtered.json' else sample['pred_label']
231 | events = [
232 | [event['start'], event['start'] + len(event['trigger']) - 1]
233 | for event in events_from_file
234 | ]
235 | new_events, predictions, probabilities = predict(args, sample['document'], events, model, tokenizer)
236 | results.append({
237 | "doc_id": sample['doc_id'],
238 | "document": sample['document'],
239 | "events": [
240 | {
241 | 'start': char_start,
242 | 'end': char_end,
243 | 'trigger': sample['document'][char_start:char_end+1]
244 | } for char_start, char_end in new_events
245 | ],
246 | "pred_label": predictions,
247 | "pred_prob": probabilities
248 | })
249 | save_name = '_gold_test_pred_corefs.json' if pred_event_file == 'test_filtered.json' else '_test_pred_corefs.json'
250 | with open(os.path.join(args.output_dir, best_save_weight + save_name), 'wt', encoding='utf-8') as f:
251 |                 for example_result in results:
252 |                     f.write(json.dumps(example_result) + '\n')
253 | # Analysis
254 | if args.do_analysis:
255 | pred_event_file = 'epoch_3_dev_f1_57.9994_weights.bin_test_pred_events.json'
256 | pred_event_filepath = os.path.join(args.output_dir, pred_event_file)
257 |
258 | analysis_weight = 'XXX_weights.bin'
259 | logger.info(f'loading weights from {analysis_weight}...')
260 | model.load_state_dict(torch.load(os.path.join(args.output_dir, analysis_weight)))
261 | logger.info(f'predicting coref labels of {analysis_weight}...')
262 | all_wrong_1, all_wrong_2 = [], []
263 | all_predictions, all_labels = [], []
264 | model.eval()
265 |         with open(args.test_file, 'rt', encoding='utf-8') as f_in:
266 | for line in tqdm(f_in.readlines()):
267 | sample = json.loads(line.strip())
268 | events = [
269 | [event['start'], event['start'] + len(event['trigger']) - 1]
270 | for event in sample['events']
271 | ]
272 | new_events, predictions, _ = predict(args, sample['document'], events, model, tokenizer)
273 | all_predictions += predictions
274 | wrong_1_list, wrong_2_list, true_labels = get_wrong_samples(
275 | sample['doc_id'],
276 | new_events, predictions,
277 | sample['events'], sample['clusters'], sample['sentences'],
278 | pred_event_filepath
279 | )
280 | all_labels += true_labels
281 | all_wrong_1 += wrong_1_list
282 | all_wrong_2 += wrong_2_list
283 | assert len(all_labels) == len(all_predictions)
284 | print(classification_report(all_labels, all_predictions))
285 | print(f'all_wrong_1: {len(all_wrong_1)}\tall_wrong_2: {len(all_wrong_2)}')
286 | with open(os.path.join(args.output_dir, analysis_weight + '_' + WRONG_TYPE[0] + '.json'), 'wt', encoding='utf-8') as f_out_1:
287 | f_out_1.write(json.dumps(all_wrong_1) + '\n')
288 | with open(os.path.join(args.output_dir, analysis_weight + '_' + WRONG_TYPE[1] + '.json'), 'wt', encoding='utf-8') as f_out_2:
289 | f_out_2.write(json.dumps(all_wrong_2) + '\n')
290 |
--------------------------------------------------------------------------------
/src/global_event_coref/run_global_base.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./M-multi-cosine_results/
2 |
3 | python3 run_global_base.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --model_type=longformer \
6 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \
7 | --train_file=../../data/train_filtered.json \
8 | --dev_file=../../data/dev_filtered.json \
9 | --test_file=../../data/test_filtered.json \
10 | --max_seq_length=4096 \
11 | --learning_rate=1e-5 \
12 | --add_contrastive_loss \
13 | --matching_style=multi_cosine \
14 | --softmax_loss=ce \
15 | --num_train_epochs=30 \
16 | --batch_size=1 \
17 | --do_train \
18 | --warmup_proportion=0. \
19 | --seed=42
--------------------------------------------------------------------------------
/src/global_event_coref/run_global_base_with_mask.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./Mask_M-multi-cosine_closs_results/
2 |
3 | python3 run_global_base_with_mask.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --model_type=longformer \
6 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \
7 | --mention_encoder_type=bert \
8 | --mention_encoder_checkpoint=../../../PT_MODELS/bert-base-cased/ \
9 | --train_file=../../data/train_filtered.json \
10 | --dev_file=../../data/dev_filtered.json \
11 | --test_file=../../data/test_filtered.json \
12 | --max_seq_length=4096 \
13 | --max_mention_length=256 \
14 | --learning_rate=1e-5 \
15 | --add_contrastive_loss \
16 | --matching_style=multi_cosine \
17 | --softmax_loss=ce \
18 | --num_train_epochs=50 \
19 | --batch_size=1 \
20 | --do_train \
21 | --warmup_proportion=0. \
22 | --seed=42
--------------------------------------------------------------------------------
/src/global_event_coref/run_global_base_with_mask_topic.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./MaskTopic_M-multi-cosine_closs_results/
2 |
3 | python3 run_global_base_with_mask_topic.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --model_type=longformer \
6 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \
7 | --mention_encoder_type=bert \
8 | --mention_encoder_checkpoint=../../../PT_MODELS/bert-base-cased/ \
9 | --topic_model=vmf \
10 | --topic_dim=32 \
11 | --topic_inter_map=64 \
12 | --train_file=../../data/train_filtered.json \
13 | --dev_file=../../data/dev_filtered.json \
14 | --test_file=../../data/test_filtered.json \
15 | --max_seq_length=4096 \
16 | --max_mention_length=256 \
17 | --learning_rate=1e-5 \
18 | --add_contrastive_loss \
19 | --matching_style=multi_cosine \
20 | --softmax_loss=ce \
21 | --num_train_epochs=50 \
22 | --batch_size=1 \
23 | --do_train \
24 | --warmup_proportion=0. \
25 | --seed=42
--------------------------------------------------------------------------------
/src/global_event_coref/run_global_base_with_topic.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from tqdm.auto import tqdm
3 | import json
4 | from collections import defaultdict, namedtuple
5 | import torch
6 | from transformers import AdamW, get_scheduler
7 | from transformers import AutoConfig, AutoTokenizer
8 | import numpy as np
9 | from sklearn.metrics import classification_report
10 | import os
11 | import sys
12 | sys.path.append('../../')
13 | from src.tools import seed_everything, NpEncoder
14 | from src.global_event_coref.arg import parse_args
15 | from src.global_event_coref.data import KBPCoref, get_dataLoader, vocab, VOCAB_SIZE
16 | from src.global_event_coref.modeling import LongformerSoftmaxForECwithTopic
17 |
18 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
19 | datefmt='%Y/%m/%d %H:%M:%S',
20 | level=logging.INFO)
21 | logger = logging.getLogger("Model")
22 | Sentence = namedtuple("Sentence", ["start", "text"])
23 |
24 | def to_device(args, batch_data):
25 | new_batch_data = {}
26 | for k, v in batch_data.items():
27 | if k in ['batch_events', 'batch_event_cluster_ids']:
28 | new_batch_data[k] = v
29 | elif k == 'batch_event_dists':
30 | new_batch_data[k] = [
31 | torch.tensor(event_dists, dtype=torch.float32).to(args.device)
32 | for event_dists in v
33 | ]
34 | elif k == 'batch_inputs':
35 | new_batch_data[k] = {
36 | k_: v_.to(args.device) for k_, v_ in v.items()
37 | }
38 | else:
39 | raise ValueError(f'Unknown batch data key: {k}')
40 | return new_batch_data
41 |
42 | def train_loop(args, dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
43 | progress_bar = tqdm(range(len(dataloader)))
44 | progress_bar.set_description(f'loss: {0:>7f}')
45 | finish_step_num = epoch * len(dataloader)
46 |
47 | model.train()
48 | for step, batch_data in enumerate(dataloader, start=1):
49 | batch_data = to_device(args, batch_data)
50 | outputs = model(**batch_data)
51 | loss = outputs[0]
52 |
53 |         if loss:  # the model may return no loss (e.g. a document without event pairs)
54 | optimizer.zero_grad()
55 | loss.backward()
56 | optimizer.step()
57 | lr_scheduler.step()
58 |
59 | total_loss += loss.item() if loss else 0.
60 | progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}')
61 | progress_bar.update(1)
62 | return total_loss
63 |
64 | def test_loop(args, dataloader, model):
65 | true_labels, true_predictions = [], []
66 | model.eval()
67 | with torch.no_grad():
68 | for batch_data in tqdm(dataloader):
69 | batch_data = to_device(args, batch_data)
70 | outputs = model(**batch_data)
71 | _, logits, masks, labels = outputs
72 |
73 | predictions = logits.argmax(dim=-1).cpu().numpy() # [batch, event_pair_num]
74 | y = labels.cpu().numpy()
75 | lens = np.sum(masks.cpu().numpy(), axis=-1)
76 | true_labels += [
77 | int(l) for label, seq_len in zip(y, lens) for idx, l in enumerate(label) if idx < seq_len
78 | ]
79 | true_predictions += [
80 | int(p) for pred, seq_len in zip(predictions, lens) for idx, p in enumerate(pred) if idx < seq_len
81 | ]
82 | return classification_report(true_labels, true_predictions, output_dict=True)
83 |
84 | def train(args, train_dataset, dev_dataset, model, tokenizer):
85 | """ Train the model """
86 | train_dataloader = get_dataLoader(args, train_dataset, tokenizer, shuffle=True, collote_fn_type='with_dist')
87 | dev_dataloader = get_dataLoader(args, dev_dataset, tokenizer, shuffle=False, collote_fn_type='with_dist')
88 | t_total = len(train_dataloader) * args.num_train_epochs
89 | # Prepare optimizer and schedule (linear warmup and decay)
90 | no_decay = ["bias", "LayerNorm.weight"]
91 | optimizer_grouped_parameters = [
92 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
93 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
94 | ]
95 | args.warmup_steps = int(t_total * args.warmup_proportion)
96 | optimizer = AdamW(
97 | optimizer_grouped_parameters,
98 | lr=args.learning_rate,
99 | betas=(args.adam_beta1, args.adam_beta2),
100 | eps=args.adam_epsilon
101 | )
102 | lr_scheduler = get_scheduler(
103 | 'linear',
104 | optimizer,
105 | num_warmup_steps=args.warmup_steps,
106 | num_training_steps=t_total
107 | )
108 | # Train!
109 | logger.info("***** Running training *****")
110 | logger.info(f"Num examples - {len(train_dataset)}")
111 | logger.info(f"Num Epochs - {args.num_train_epochs}")
112 | logger.info(f"Total optimization steps - {t_total}")
113 | with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f:
114 | f.write(str(args))
115 |
116 | total_loss = 0.
117 | best_f1 = 0.
118 | for epoch in range(args.num_train_epochs):
119 | print(f"Epoch {epoch+1}/{args.num_train_epochs}\n-------------------------------")
120 | total_loss = train_loop(args, train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss)
121 | metrics = test_loop(args, dev_dataloader, model)
122 | dev_p, dev_r, dev_f1 = metrics['1']['precision'], metrics['1']['recall'], metrics['1']['f1-score']
123 | logger.info(f'Dev: P - {(100*dev_p):0.4f} R - {(100*dev_r):0.4f} F1 - {(100*dev_f1):0.4f}')
124 | if dev_f1 > best_f1:
125 | best_f1 = dev_f1
126 | logger.info(f'saving new weights to {args.output_dir}...\n')
127 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin'
128 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight))
129 |         elif 100 * dev_p > 69 and 100 * dev_r > 69:  # also keep checkpoints whose precision and recall are both above 69
130 | logger.info(f'saving new weights to {args.output_dir}...\n')
131 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin'
132 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight))
133 | with open(os.path.join(args.output_dir, 'dev_metrics.txt'), 'at') as f:
134 | f.write(f'epoch_{epoch+1}\n' + json.dumps(metrics, cls=NpEncoder) + '\n\n')
135 | logger.info("Done!")
136 |
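137 | # Maps each event's character span to token indices. char_to_token may return
138 | # None when the start character is not covered by a token (e.g. whitespace),
139 | # hence the char_start + 1 fallback; events that fall outside the truncated
140 | # window are dropped, so new_events can be shorter than the input events.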
137 | def predict(args, document:str, events:list, event_dists:list, model, tokenizer):
138 | assert len(events) == len(event_dists)
139 | inputs = tokenizer(
140 | document,
141 | max_length=args.max_seq_length,
142 | truncation=True,
143 | return_tensors="pt"
144 | )
145 | filtered_events = []
146 | new_events = []
147 | filtered_dists = []
148 | for event, event_dist in zip(events, event_dists):
149 | char_start, char_end = event
150 | token_start = inputs.char_to_token(char_start)
151 | if not token_start:
152 | token_start = inputs.char_to_token(char_start + 1)
153 | token_end = inputs.char_to_token(char_end)
154 | if not token_start or not token_end:
155 | continue
156 | filtered_events.append([token_start, token_end])
157 | new_events.append(event)
158 | filtered_dists.append(event_dist)
159 | if not new_events:
160 | return [], [], []
161 | inputs = {
162 | 'batch_inputs': inputs,
163 | 'batch_events': [filtered_events],
164 | 'batch_event_dists': [np.asarray(filtered_dists)]
165 | }
166 | inputs = to_device(args, inputs)
167 | with torch.no_grad():
168 | outputs = model(**inputs)
169 | logits = outputs[1]
170 | predictions = logits.argmax(dim=-1)[0].cpu().numpy().tolist()
171 | probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy().tolist()
172 | probabilities = [probabilities[idx][pred] for idx, pred in enumerate(predictions)]
173 | if len(new_events) > 1:
174 | assert len(predictions) == len(new_events) * (len(new_events) - 1) / 2
175 | return new_events, predictions, probabilities
176 |
177 | def test(args, test_dataset, model, tokenizer, save_weights:list):
178 | test_dataloader = get_dataLoader(
179 | args, test_dataset, tokenizer, batch_size=1, shuffle=False,
180 | collote_fn_type='with_dist'
181 | )
182 | logger.info('***** Running testing *****')
183 | for save_weight in save_weights:
184 | logger.info(f'loading weights from {save_weight}...')
185 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight)))
186 | metrics = test_loop(args, test_dataloader, model)
187 | with open(os.path.join(args.output_dir, 'test_metrics.txt'), 'at') as f:
188 | f.write(f'{save_weight}\n{json.dumps(metrics, cls=NpEncoder)}\n\n')
189 |
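190 | # Builds a binary bag-of-words vector over `vocab` for the event's local context
191 | # (the trigger's sentence plus its immediate neighbours). Note that
192 | # `w in event_mention` is substring matching, not token matching, so short
193 | # vocabulary entries can also fire inside longer words. Returns None when no
194 | # sentence covers the span.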
190 | def get_event_dist(e_start, e_end, sents):
191 | for s_idx, sent in enumerate(sents):
192 | sent_end = sent.start + len(sent.text) - 1
193 | if e_start >= sent.start and e_end <= sent_end:
194 | before = sents[s_idx - 1].text if s_idx > 0 else ''
195 | after = sents[s_idx + 1].text if s_idx < len(sents) - 1 else ''
196 | event_mention = before + (' ' if len(before) > 0 else '') + sent.text + ' ' + after
197 | event_mention = event_mention.lower()
198 | return [1 if w in event_mention else 0 for w in vocab]
199 | return None
200 |
201 | if __name__ == '__main__':
202 | args = parse_args()
203 | if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir):
204 | raise ValueError(f'Output directory ({args.output_dir}) already exists and is not empty.')
205 | if not os.path.exists(args.output_dir):
206 | os.mkdir(args.output_dir)
207 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
208 | args.n_gpu = torch.cuda.device_count()
209 | logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}')
210 | # Set seed
211 | seed_everything(args.seed)
212 | # Load pretrained model and tokenizer
213 |     logger.info(f'using model {"with" if args.add_contrastive_loss else "without"} contrastive loss')
214 | logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...')
215 | main_config = AutoConfig.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir)
216 | tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir)
217 | args.num_labels = 2
218 | args.dist_dim = VOCAB_SIZE
219 | model = LongformerSoftmaxForECwithTopic.from_pretrained(
220 | args.model_checkpoint,
221 | config=main_config,
222 | cache_dir=args.cache_dir,
223 | args=args
224 | ).to(args.device)
225 | # Training
226 | save_weights = []
227 | if args.do_train:
228 | train_dataset = KBPCoref(args.train_file)
229 | dev_dataset = KBPCoref(args.dev_file)
230 | train(args, train_dataset, dev_dataset, model, tokenizer)
231 | # Testing
232 | save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')]
233 | if args.do_test:
234 | test_dataset = KBPCoref(args.test_file)
235 | test(args, test_dataset, model, tokenizer, save_weights)
236 | # Predicting
237 | if args.do_predict:
238 | kbp_sent_dic = defaultdict(list) # {filename: [Sentence]}
239 |         with open('../../data/kbp_sent.txt', 'rt', encoding='utf-8') as sents:
240 | for line in sents:
241 | doc_id, start, text = line.strip().split('\t')
242 | kbp_sent_dic[doc_id].append(Sentence(int(start), text))
243 |
244 | pred_event_file = 'epoch_3_dev_f1_57.9994_weights.bin_test_pred_events.json'
245 | # pred_event_file = 'test_filtered.json'
246 |
247 | for best_save_weight in save_weights:
248 | logger.info(f'loading weights from {best_save_weight}...')
249 | model.load_state_dict(torch.load(os.path.join(args.output_dir, best_save_weight)))
250 | logger.info(f'predicting coref labels of {best_save_weight}...')
251 |
252 | results = []
253 | model.eval()
254 |             with open(os.path.join(args.output_dir, pred_event_file), 'rt', encoding='utf-8') as f_in:
255 | for line in tqdm(f_in.readlines()):
256 | sample = json.loads(line.strip())
257 | events_from_file = sample['events'] if pred_event_file == 'test_filtered.json' else sample['pred_label']
258 | events = [
259 | [event['start'], event['start'] + len(event['trigger']) - 1]
260 | for event in events_from_file
261 | ]
262 | sents = kbp_sent_dic[sample['doc_id']]
263 | event_dists = []
264 | for event in events_from_file:
265 | e_dist = get_event_dist(event['start'], event['start'] + len(event['trigger']) - 1, sents)
266 | assert e_dist is not None
267 | event_dists.append(e_dist)
268 | new_events, predictions, probabilities = predict(
269 | args, sample['document'], events, event_dists, model, tokenizer
270 | )
271 | results.append({
272 | "doc_id": sample['doc_id'],
273 | "document": sample['document'],
274 | "events": [
275 | {
276 | 'start': char_start,
277 | 'end': char_end,
278 | 'trigger': sample['document'][char_start:char_end+1]
279 | } for char_start, char_end in new_events
280 | ],
281 | "pred_label": predictions,
282 | "pred_prob": probabilities
283 | })
284 | save_name = '_gold_test_pred_corefs.json' if pred_event_file == 'test_filtered.json' else '_test_pred_corefs.json'
285 | with open(os.path.join(args.output_dir, best_save_weight + save_name), 'wt', encoding='utf-8') as f:
286 |                 for example_result in results:
287 |                     f.write(json.dumps(example_result) + '\n')
--------------------------------------------------------------------------------
/src/global_event_coref/run_global_base_with_topic.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./Topic_M-multi-cosine_closs_results/
2 |
3 | python3 run_global_base_with_topic.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --model_type=longformer \
6 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \
7 | --topic_model=vmf \
8 | --topic_dim=32 \
9 | --topic_inter_map=64 \
10 | --train_file=../../data/train_filtered.json \
11 | --dev_file=../../data/dev_filtered.json \
12 | --test_file=../../data/test_filtered.json \
13 | --max_seq_length=4096 \
14 | --learning_rate=1e-5 \
15 | --add_contrastive_loss \
16 | --matching_style=multi_cosine \
17 | --softmax_loss=ce \
18 | --num_train_epochs=50 \
19 | --batch_size=1 \
20 | --do_train \
21 | --warmup_proportion=0. \
22 | --seed=42
--------------------------------------------------------------------------------
/src/joint_model/arg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_args():
4 | parser = argparse.ArgumentParser()
5 |
6 | # Required parameters
7 | parser.add_argument("--output_dir", default=None, type=str, required=True,
8 | help="The output directory where the model checkpoints and predictions will be written.",
9 | )
10 | parser.add_argument("--train_file", default=None, type=str, required=True, help="The input training file.")
11 | parser.add_argument("--dev_file", default=None, type=str, required=True, help="The input evaluation file.")
12 | parser.add_argument("--test_file", default=None, type=str, required=True, help="The input testing file.")
13 |
14 | parser.add_argument("--model_type",
15 | default="longformer", type=str, required=True
16 | )
17 | parser.add_argument("--model_checkpoint",
18 | default="allenai/longformer-base-4096", type=str, required=True,
19 | help="Path to pretrained model or model identifier from huggingface.co/models",
20 | )
21 | parser.add_argument("--max_seq_length", default=4096, type=int, required=True)
22 | parser.add_argument("--matching_style", default="multi", type=str, required=True,
23 | help="how to match two event representations"
24 | )
25 |
26 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
27 | parser.add_argument("--do_test", action="store_true", help="Whether to run eval on the test set.")
28 | parser.add_argument("--do_predict", action="store_true", help="Whether to save predicted labels.")
29 |
30 | # Other parameters
31 | parser.add_argument("--cache_dir", default=None, type=str,
32 | help="Where do you want to store the pre-trained models downloaded from s3"
33 | )
34 | parser.add_argument("--topic_model", default='stm', type=str,
35 | choices=['stm', 'stm_bn', 'vmf']
36 | )
37 | parser.add_argument("--topic_dim", default=32, type=int)
38 | parser.add_argument("--topic_inter_map", default=64, type=int)
39 | parser.add_argument("--mention_encoder_type", default="bert", type=str)
40 | parser.add_argument("--mention_encoder_checkpoint",
41 | default="bert-base-cased", type=str,
42 | help="Path to pretrained model or model identifier from huggingface.co/models",
43 | )
44 | parser.add_argument("--max_mention_length", default=256, type=int)
45 | parser.add_argument("--add_contrastive_loss", action="store_true")
46 | parser.add_argument("--softmax_loss", default='ce', type=str,
47 | help="The loss function for softmax model.",
48 | choices=['lsr', 'focal', 'ce']
49 | )
50 |
51 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
52 | parser.add_argument("--num_train_epochs", default=30, type=int, help="Total number of training epochs to perform.")
53 | parser.add_argument("--batch_size", default=1, type=int)
54 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
55 |
56 |     parser.add_argument("--adam_beta1", default=0.9, type=float,
57 |                         help="Beta1 for Adam optimizer."
58 |     )
59 |     parser.add_argument("--adam_beta2", default=0.98, type=float,
60 |                         help="Beta2 for Adam optimizer."
61 |     )
62 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
63 | help="Epsilon for Adam optimizer."
64 | )
65 |     parser.add_argument("--warmup_proportion", default=0.1, type=float,
66 |                         help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training."
67 |     )
68 |     parser.add_argument("--weight_decay", default=0.01, type=float,
69 |                         help="Weight decay to apply, if any."
70 |     )
71 | args = parser.parse_args()
72 | return args
--------------------------------------------------------------------------------
/src/joint_model/modeling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import CrossEntropyLoss
4 | from transformers import LongformerPreTrainedModel, LongformerModel
5 | from transformers import BertModel, RobertaModel
6 | from allennlp.modules.span_extractors import SelfAttentiveSpanExtractor
7 | from ..tools import LabelSmoothingCrossEntropy, FocalLoss
8 | from ..tools import SimpleTopicModel, SimpleTopicModelwithBN, SimpleTopicVMFModel
9 |
10 | MENTION_ENCODER = {
11 | 'bert': BertModel,
12 | 'roberta': RobertaModel
13 | }
14 | TOPIC_MODEL = {
15 | 'stm': SimpleTopicModel,
16 | 'stm_bn': SimpleTopicModelwithBN,
17 | 'vmf': SimpleTopicVMFModel
18 | }
19 | COSINE_SPACE_DIM = 64
20 | COSINE_SLICES = 128
21 | COSINE_FACTOR = 4
22 |
23 | class LongformerSoftmaxForEC(LongformerPreTrainedModel):
24 | def __init__(self, config, args):
25 | super().__init__(config)
26 | self.trigger_num_labels = args.trigger_num_labels
27 | self.num_labels = args.num_labels
28 | self.hidden_size = config.hidden_size
29 | self.loss_type = args.softmax_loss
30 | self.add_contrastive_loss = args.add_contrastive_loss
31 | self.use_device = args.device
32 | # encoder & pooler
33 | self.longformer = LongformerModel(config, add_pooling_layer=False)
34 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
35 | self.span_extractor = SelfAttentiveSpanExtractor(input_dim=self.hidden_size)
36 | self.td_classifier = nn.Linear(self.hidden_size, self.trigger_num_labels)
37 | # event matching
38 | self.matching_style = args.matching_style
39 | if 'cosine' not in self.matching_style:
40 | if self.matching_style == 'base':
41 | multiples = 2
42 | elif self.matching_style == 'multi':
43 | multiples = 3
44 | self.coref_classifier = nn.Linear(multiples * self.hidden_size, self.num_labels)
45 | else:
46 | self.cosine_space_dim, self.cosine_slices, self.tensor_factor = COSINE_SPACE_DIM, COSINE_SLICES, COSINE_FACTOR
47 | self.cosine_mat_p = nn.Parameter(torch.rand((self.tensor_factor, self.cosine_slices), requires_grad=True))
48 | self.cosine_mat_q = nn.Parameter(torch.rand((self.tensor_factor, self.cosine_space_dim), requires_grad=True))
49 | self.cosine_ffnn = nn.Linear(self.hidden_size, self.cosine_space_dim)
50 | if self.matching_style == 'cosine':
51 | self.coref_classifier = nn.Linear(2 * self.hidden_size + self.cosine_slices, self.num_labels)
52 | elif self.matching_style == 'multi_cosine':
53 | self.coref_classifier = nn.Linear(3 * self.hidden_size + self.cosine_slices, self.num_labels)
54 | elif self.matching_style == 'multi_dist_cosine':
55 | self.coref_classifier = nn.Linear(4 * self.hidden_size + self.cosine_slices, self.num_labels)
56 | self.post_init()
57 |
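58 |     # Factored multi-slice cosine matching: each event representation is projected
59 |     # to COSINE_SPACE_DIM dims, expanded into COSINE_SLICES views through the
60 |     # low-rank factors cosine_mat_q and cosine_mat_p, L2-normalized per view, and
61 |     # compared slice-wise, yielding one cosine score per slice for every pair.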
58 | def _multi_cosine(self, batch_event_1_reps, batch_event_2_reps):
59 | batch_event_1_reps = self.cosine_ffnn(batch_event_1_reps)
60 | batch_event_1_reps = batch_event_1_reps.unsqueeze(dim=2)
61 | batch_event_1_reps = self.cosine_mat_q * batch_event_1_reps
62 | batch_event_1_reps = batch_event_1_reps.permute((0, 1, 3, 2))
63 | batch_event_1_reps = torch.matmul(batch_event_1_reps, self.cosine_mat_p)
64 | batch_event_1_reps = batch_event_1_reps.permute((0, 1, 3, 2))
65 | # vector normalization
66 | norms_1 = (batch_event_1_reps ** 2).sum(axis=-1, keepdims=True) ** 0.5
67 | batch_event_1_reps = batch_event_1_reps / norms_1
68 |
69 | batch_event_2_reps = self.cosine_ffnn(batch_event_2_reps)
70 | batch_event_2_reps = batch_event_2_reps.unsqueeze(dim=2)
71 | batch_event_2_reps = self.cosine_mat_q * batch_event_2_reps
72 | batch_event_2_reps = batch_event_2_reps.permute((0, 1, 3, 2))
73 | batch_event_2_reps = torch.matmul(batch_event_2_reps, self.cosine_mat_p)
74 | batch_event_2_reps = batch_event_2_reps.permute((0, 1, 3, 2))
75 | # vector normalization
76 | norms_2 = (batch_event_2_reps ** 2).sum(axis=-1, keepdims=True) ** 0.5
77 | batch_event_2_reps = batch_event_2_reps / norms_2
78 |
79 | return torch.sum(batch_event_1_reps * batch_event_2_reps, dim=-1)
80 |
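81 |     # Circle-style contrastive loss on scaled cosine similarities: for every
82 |     # (non-coref, coref) pair of event pairs it penalizes the non-coref similarity
83 |     # exceeding the coref one. The prepended 0 turns the logsumexp into
84 |     # log(1 + sum(exp(...))), keeping the loss non-negative; l is the scale factor.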
81 | def _cal_circle_loss(self, event_1_reps, event_2_reps, coref_labels, l=20.):
82 | norms_1 = (event_1_reps ** 2).sum(axis=1, keepdims=True) ** 0.5
83 | event_1_reps = event_1_reps / norms_1
84 | norms_2 = (event_2_reps ** 2).sum(axis=1, keepdims=True) ** 0.5
85 | event_2_reps = event_2_reps / norms_2
86 | event_cos = torch.sum(event_1_reps * event_2_reps, dim=1) * l
87 | # calculate the difference between each pair of Cosine values
88 | event_cos_diff = event_cos[:, None] - event_cos[None, :]
89 | # find (noncoref, coref) index
90 | select_idx = coref_labels[:, None] < coref_labels[None, :]
91 | select_idx = select_idx.float()
92 |
93 | event_cos_diff = event_cos_diff - (1 - select_idx) * 1e12
94 | event_cos_diff = event_cos_diff.view(-1)
95 | event_cos_diff = torch.cat((torch.tensor([0.0], device=self.use_device), event_cos_diff), dim=0)
96 | return torch.logsumexp(event_cos_diff, dim=0)
97 |
98 | def _matching_func(self, batch_event_1_reps, batch_event_2_reps):
99 | if self.matching_style == 'base':
100 | batch_seq_reps = torch.cat([batch_event_1_reps, batch_event_2_reps], dim=-1)
101 | elif self.matching_style == 'multi':
102 | batch_e1_e2_multi = batch_event_1_reps * batch_event_2_reps
103 | batch_seq_reps = torch.cat([batch_event_1_reps, batch_event_2_reps, batch_e1_e2_multi], dim=-1)
104 | elif self.matching_style == 'cosine':
105 | batch_multi_cosine = self._multi_cosine(batch_event_1_reps, batch_event_2_reps)
106 | batch_seq_reps = torch.cat([batch_event_1_reps, batch_event_2_reps, batch_multi_cosine], dim=-1)
107 | elif self.matching_style == 'multi_cosine':
108 | batch_e1_e2_multi = batch_event_1_reps * batch_event_2_reps
109 | batch_multi_cosine = self._multi_cosine(batch_event_1_reps, batch_event_2_reps)
110 | batch_seq_reps = torch.cat([batch_event_1_reps, batch_event_2_reps, batch_e1_e2_multi, batch_multi_cosine], dim=-1)
111 | elif self.matching_style == 'multi_dist_cosine':
112 | batch_e1_e2_multi = batch_event_1_reps * batch_event_2_reps
113 | batch_e1_e2_dist = torch.abs(batch_event_1_reps - batch_event_2_reps)
114 | batch_multi_cosine = self._multi_cosine(batch_event_1_reps, batch_event_2_reps)
115 | batch_seq_reps = torch.cat([batch_event_1_reps, batch_event_2_reps, batch_e1_e2_multi, batch_e1_e2_dist, batch_multi_cosine], dim=-1)
116 | return batch_seq_reps
117 |
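118 |     # Joint forward pass: token-level trigger logits are predicted first; event
119 |     # spans are then paired exhaustively (i < j), padded to the longest pair list
120 |     # in the batch, pooled with the span extractor, matched, and classified for
121 |     # coreference. In training the two losses are combined on a log scale (below).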
118 | def forward(self, batch_inputs, batch_events=None, batch_td_labels=None, batch_event_cluster_ids=None):
119 | outputs = self.longformer(**batch_inputs)
120 | sequence_output = outputs[0]
121 | sequence_output = self.dropout(sequence_output)
122 | # predict trigger
123 | td_logits = self.td_classifier(sequence_output)
124 | if batch_events is None:
125 | return None, td_logits
126 | # construct event pairs (event_1, event_2)
127 | batch_event_1_list, batch_event_2_list = [], []
128 | max_len, batch_event_mask = 0, []
129 | if batch_event_cluster_ids is not None:
130 | batch_coref_labels = []
131 | for events, event_cluster_ids in zip(batch_events, batch_event_cluster_ids):
132 | event_1_list, event_2_list, coref_labels = [], [], []
133 | for i in range(len(events) - 1):
134 | for j in range(i + 1, len(events)):
135 | event_1_list.append(events[i])
136 | event_2_list.append(events[j])
137 | cluster_id_1, cluster_id_2 = event_cluster_ids[i], event_cluster_ids[j]
138 | coref_labels.append(1 if cluster_id_1 == cluster_id_2 else 0)
139 | max_len = max(max_len, len(coref_labels))
140 | batch_event_1_list.append(event_1_list)
141 | batch_event_2_list.append(event_2_list)
142 | batch_coref_labels.append(coref_labels)
143 | batch_event_mask.append([1] * len(coref_labels))
144 | # padding
145 | for b_idx in range(len(batch_coref_labels)):
146 | pad_length = max_len - len(batch_coref_labels[b_idx]) if max_len > 0 else 1
147 | batch_event_1_list[b_idx] += [[0, 0]] * pad_length
148 | batch_event_2_list[b_idx] += [[0, 0]] * pad_length
149 | batch_coref_labels[b_idx] += [0] * pad_length
150 | batch_event_mask[b_idx] += [0] * pad_length
151 | else:
152 | for events in batch_events:
153 | event_1_list, event_2_list = [], []
154 | for i in range(len(events) - 1):
155 | for j in range(i + 1, len(events)):
156 | event_1_list.append(events[i])
157 | event_2_list.append(events[j])
158 | max_len = max(max_len, len(event_1_list))
159 | batch_event_1_list.append(event_1_list)
160 | batch_event_2_list.append(event_2_list)
161 | batch_event_mask.append([1] * len(event_1_list))
162 | # padding
163 | for b_idx in range(len(batch_event_mask)):
164 | pad_length = max_len - len(batch_event_mask[b_idx]) if max_len > 0 else 1
165 | batch_event_1_list[b_idx] += [[0, 0]] * pad_length
166 | batch_event_2_list[b_idx] += [[0, 0]] * pad_length
167 | batch_event_mask[b_idx] += [0] * pad_length
168 | # extract events & predict coref
169 | batch_event_1 = torch.tensor(batch_event_1_list).to(self.use_device)
170 | batch_event_2 = torch.tensor(batch_event_2_list).to(self.use_device)
171 | batch_mask = torch.tensor(batch_event_mask).to(self.use_device)
172 | batch_event_1_reps = self.span_extractor(sequence_output, batch_event_1, span_indices_mask=batch_mask)
173 | batch_event_2_reps = self.span_extractor(sequence_output, batch_event_2, span_indices_mask=batch_mask)
174 | batch_seq_reps = self._matching_func(batch_event_1_reps, batch_event_2_reps)
175 | coref_logits = self.coref_classifier(batch_seq_reps)
176 | # calculate loss
177 | loss, batch_ec_labels = None, None
178 | attention_mask = batch_inputs['attention_mask']
179 | if batch_event_cluster_ids is not None and max_len > 0:
180 | assert self.loss_type in ['lsr', 'focal', 'ce']
181 | if self.loss_type == 'lsr':
182 | loss_fct = LabelSmoothingCrossEntropy()
183 | elif self.loss_type == 'focal':
184 | loss_fct = FocalLoss()
185 | else:
186 | loss_fct = CrossEntropyLoss()
187 | # trigger detection loss
188 | active_td_loss = attention_mask.view(-1) == 1
189 | active_td_logits = td_logits.view(-1, self.trigger_num_labels)[active_td_loss]
190 | active_td_labels = batch_td_labels.view(-1)[active_td_loss]
191 | loss_td = loss_fct(active_td_logits, active_td_labels)
192 | # event coreference loss
193 | active_coref_loss = batch_mask.view(-1) == 1
194 | active_coref_logits = coref_logits.view(-1, self.num_labels)[active_coref_loss]
195 | batch_ec_labels = torch.tensor(batch_coref_labels).to(self.use_device)
196 | active_coref_labels = batch_ec_labels.view(-1)[active_coref_loss]
197 | loss_coref = loss_fct(active_coref_logits, active_coref_labels)
198 | if self.add_contrastive_loss:
199 | active_event_1_reps = batch_event_1_reps.view(-1, self.hidden_size)[active_coref_loss]
200 | active_event_2_reps = batch_event_2_reps.view(-1, self.hidden_size)[active_coref_loss]
201 |                 loss_contrastive = self._cal_circle_loss(active_event_1_reps, active_event_2_reps, active_coref_labels)
202 |                 loss = torch.log(1 + loss_td) + torch.log(1 + loss_coref) + 0.2 * loss_contrastive  # log-damped task losses plus a weighted contrastive term
203 | else:
204 | loss = torch.log(1 + loss_td) + torch.log(1 + loss_coref)
205 | return loss, td_logits, coref_logits, attention_mask, batch_td_labels, batch_mask, batch_ec_labels
206 |
--------------------------------------------------------------------------------
/src/joint_model/run_joint_base.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./M-multi-cosine_results/
2 |
3 | python3 run_joint_base.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --model_type=longformer \
6 | --model_checkpoint=../../../PT_MODELS/allenai/longformer-large-4096/ \
7 | --train_file=../../data/train_filtered.json \
8 | --dev_file=../../data/dev_filtered.json \
9 | --test_file=../../data/test_filtered.json \
10 | --max_seq_length=4096 \
11 | --learning_rate=1e-5 \
12 | --add_contrastive_loss \
13 | --matching_style=multi_cosine \
14 | --softmax_loss=ce \
15 | --num_train_epochs=30 \
16 | --batch_size=1 \
17 | --do_train \
18 | --warmup_proportion=0. \
19 | --seed=42
--------------------------------------------------------------------------------
/src/local_event_coref/arg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_args():
4 | parser = argparse.ArgumentParser()
5 |
6 | # Required parameters
7 | parser.add_argument("--output_dir", default=None, type=str, required=True,
8 | help="The output directory where the model checkpoints and predictions will be written.",
9 | )
10 | parser.add_argument("--train_file", default=None, type=str, required=True, help="The input training file.")
11 | parser.add_argument("--dev_file", default=None, type=str, required=True, help="The input evaluation file.")
12 | parser.add_argument("--test_file", default=None, type=str, required=True, help="The input testing file.")
13 |
14 | parser.add_argument("--model_type",
15 | default="bert", type=str, required=True
16 | )
17 | parser.add_argument("--model_checkpoint",
18 | default="bert-large-cased/", type=str, required=True,
19 | help="Path to pretrained model or model identifier from huggingface.co/models",
20 | )
21 | parser.add_argument("--max_seq_length", default=512, type=int, required=True)
22 | parser.add_argument("--matching_style", default="multi", type=str, required=True,
23 | help="how to match two event representations"
24 | )
25 |
26 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
27 | parser.add_argument("--do_test", action="store_true", help="Whether to run eval on the test set.")
28 | parser.add_argument("--do_predict", action="store_true", help="Whether to save predicted labels.")
29 |
30 | # Other parameters
31 | parser.add_argument("--cache_dir", default=None, type=str,
32 | help="Where do you want to store the pre-trained models downloaded from s3"
33 | )
34 | parser.add_argument("--topic_model", default='stm', type=str,
35 | choices=['stm', 'stm_bn', 'vmf']
36 | )
37 | parser.add_argument("--topic_dim", default=32, type=int)
38 | parser.add_argument("--topic_inter_map", default=64, type=int)
39 | parser.add_argument("--softmax_loss", default='ce', type=str,
40 | help="The loss function for softmax model.",
41 | choices=['lsr', 'focal', 'ce']
42 | )
43 |
44 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
45 | parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.")
46 | parser.add_argument("--batch_size", default=4, type=int)
47 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
48 |
49 |     parser.add_argument("--adam_beta1", default=0.9, type=float,
50 |                         help="Beta1 for Adam optimizer."
51 |     )
52 |     parser.add_argument("--adam_beta2", default=0.98, type=float,
53 |                         help="Beta2 for Adam optimizer."
54 |     )
55 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
56 | help="Epsilon for Adam optimizer."
57 | )
58 |     parser.add_argument("--warmup_proportion", default=0.1, type=float,
59 |                         help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training."
60 |     )
61 |     parser.add_argument("--weight_decay", default=0.01, type=float,
62 |                         help="Weight decay to apply, if any."
63 |     )
64 | args = parser.parse_args()
65 | return args
66 |
--------------------------------------------------------------------------------
/src/local_event_coref/run_local_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import torch
4 | import json
5 | from collections import namedtuple, defaultdict
6 | from tqdm.auto import tqdm
7 | from transformers import AutoConfig, AutoTokenizer
8 | from transformers import AdamW, get_scheduler
9 | from sklearn.metrics import classification_report
10 | import sys
11 | sys.path.append('../../')
12 | from src.tools import seed_everything, NpEncoder
13 | from src.local_event_coref.arg import parse_args
14 | from src.local_event_coref.data import KBPCorefPair, get_dataLoader
15 | from src.local_event_coref.modeling import BertForPairwiseEC, RobertaForPairwiseEC
16 |
17 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
18 | datefmt='%Y/%m/%d %H:%M:%S',
19 | level=logging.INFO)
20 | logger = logging.getLogger("Model")
21 | Sentence = namedtuple("Sentence", ["start", "text"])
22 |
23 | MODEL_CLASSES = {
24 | 'bert': BertForPairwiseEC,
25 | 'spanbert': BertForPairwiseEC,
26 | 'roberta': RobertaForPairwiseEC
27 | }
28 |
29 | def to_device(args, batch_data):
30 | new_batch_data = {}
31 | for k, v in batch_data.items():
32 | if k == 'batch_inputs':
33 | new_batch_data[k] = {
34 | k_: v_.to(args.device) for k_, v_ in v.items()
35 | }
36 | else:
37 | new_batch_data[k] = torch.tensor(v).to(args.device)
38 | return new_batch_data
39 |
40 | def train_loop(args, dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
41 | progress_bar = tqdm(range(len(dataloader)))
42 | progress_bar.set_description(f'loss: {0:>7f}')
43 | finish_step_num = epoch * len(dataloader)
44 |
45 | model.train()
46 | for step, batch_data in enumerate(dataloader, start=1):
47 | batch_data = to_device(args, batch_data)
48 | outputs = model(**batch_data)
49 | loss = outputs[0]
50 |
51 | optimizer.zero_grad()
52 | loss.backward()
53 | optimizer.step()
54 | lr_scheduler.step()
55 |
56 | total_loss += loss.item()
57 | progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}')
58 | progress_bar.update(1)
59 | return total_loss
60 |
61 | def test_loop(args, dataloader, model):
62 | true_labels, true_predictions = [], []
63 | model.eval()
64 | with torch.no_grad():
65 | for batch_data in tqdm(dataloader):
66 | batch_data = to_device(args, batch_data)
67 | outputs = model(**batch_data)
68 | logits = outputs[1]
69 |
70 | predictions = logits.argmax(dim=-1).cpu().numpy().tolist()
71 | labels = batch_data['labels'].cpu().numpy()
72 | true_predictions += predictions
73 | true_labels += [int(label) for label in labels]
74 | return classification_report(true_labels, true_predictions, output_dict=True)
75 |
76 | def train(args, train_dataset, dev_dataset, model, tokenizer):
77 | """ Train the model """
78 | train_dataloader = get_dataLoader(args, train_dataset, tokenizer, shuffle=True)
79 | dev_dataloader = get_dataLoader(args, dev_dataset, tokenizer, shuffle=False)
80 | t_total = len(train_dataloader) * args.num_train_epochs
81 | # Prepare optimizer and schedule (linear warmup and decay)
82 | no_decay = ["bias", "LayerNorm.weight"]
83 | optimizer_grouped_parameters = [
84 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
85 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
86 | ]
87 | args.warmup_steps = int(t_total * args.warmup_proportion)
88 | optimizer = AdamW(
89 | optimizer_grouped_parameters,
90 | lr=args.learning_rate,
91 | betas=(args.adam_beta1, args.adam_beta2),
92 | eps=args.adam_epsilon
93 | )
94 | lr_scheduler = get_scheduler(
95 | 'linear',
96 | optimizer,
97 | num_warmup_steps=args.warmup_steps,
98 | num_training_steps=t_total
99 | )
100 | # Train!
101 | logger.info("***** Running training *****")
102 | logger.info(f"Num examples - {len(train_dataset)}")
103 | logger.info(f"Num Epochs - {args.num_train_epochs}")
104 | logger.info(f"Total optimization steps - {t_total}")
105 | with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f:
106 | f.write(str(args))
107 |
108 | total_loss = 0.
109 | best_f1 = 0.
110 | for epoch in range(args.num_train_epochs):
111 | print(f"Epoch {epoch+1}/{args.num_train_epochs}\n-------------------------------")
112 | total_loss = train_loop(args, train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss)
113 | metrics = test_loop(args, dev_dataloader, model)
114 | dev_p, dev_r, dev_f1 = metrics['1']['precision'], metrics['1']['recall'], metrics['1']['f1-score']
115 | logger.info(f'Dev: P - {(100*dev_p):0.4f} R - {(100*dev_r):0.4f} F1 - {(100*dev_f1):0.4f}')
116 | if dev_f1 > best_f1:
117 | best_f1 = dev_f1
118 | logger.info(f'saving new weights to {args.output_dir}...\n')
119 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin'
120 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight))
121 | with open(os.path.join(args.output_dir, 'dev_metrics.txt'), 'at') as f:
122 | f.write(f'epoch_{epoch+1}\n' + json.dumps(metrics, cls=NpEncoder) + '\n\n')
123 | logger.info("Done!")
124 |
125 | def test(args, test_dataset, model, tokenizer, save_weights:list):
126 | test_dataloader = get_dataLoader(args, test_dataset, tokenizer, batch_size=1, shuffle=False)
127 | logger.info('***** Running testing *****')
128 | for save_weight in save_weights:
129 | logger.info(f'loading weights from {save_weight}...')
130 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight)))
131 | metrics = test_loop(args, test_dataloader, model)
132 | with open(os.path.join(args.output_dir, 'test_metrics.txt'), 'at') as f:
133 | f.write(f'{save_weight}\n{json.dumps(metrics, cls=NpEncoder)}\n\n')
134 |
135 | def predict(args, sent_1, sent_2, e1_char_start, e1_char_end, e2_char_start, e2_char_end, model, tokenizer):
136 |
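137 |     # _cut_sent keeps at most max_length whitespace-separated words on each side
138 |     # of the trigger so that both mention sentences fit within max_seq_length,
139 |     # and recomputes the trigger's character offsets in the shortened sentence.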
137 | def _cut_sent(sent, e_char_start, e_char_end, max_length):
138 | before = ' '.join([c for c in sent[:e_char_start].split(' ') if c != ''][-max_length:]).strip()
139 | trigger = sent[e_char_start:e_char_end+1]
140 | after = ' '.join([c for c in sent[e_char_end+1:].split(' ') if c != ''][:max_length]).strip()
141 | new_sent, new_char_start, new_char_end = before + ' ' + trigger + ' ' + after, len(before) + 1, len(before) + len(trigger)
142 | assert new_sent[new_char_start:new_char_end+1] == trigger
143 | return new_sent, new_char_start, new_char_end
144 |
145 |     max_mention_length = (args.max_seq_length - 50) // 4  # rough per-side word budget: 2 sentences x (before, after) contexts, with a margin for special tokens
146 | sent_1, e1_char_start, e1_char_end = _cut_sent(sent_1, e1_char_start, e1_char_end, max_mention_length)
147 | sent_2, e2_char_start, e2_char_end = _cut_sent(sent_2, e2_char_start, e2_char_end, max_mention_length)
148 | inputs = tokenizer(
149 | sent_1,
150 | sent_2,
151 | max_length=args.max_seq_length,
152 | truncation=True,
153 | return_tensors="pt"
154 | )
155 | e1_token_start = inputs.char_to_token(e1_char_start, sequence_index=0)
156 | if not e1_token_start:
157 | e1_token_start = inputs.char_to_token(e1_char_start + 1, sequence_index=0)
158 | e1_token_end = inputs.char_to_token(e1_char_end, sequence_index=0)
159 | e2_token_start = inputs.char_to_token(e2_char_start, sequence_index=1)
160 | if not e2_token_start:
161 | e2_token_start = inputs.char_to_token(e2_char_start + 1, sequence_index=1)
162 | e2_token_end = inputs.char_to_token(e2_char_end, sequence_index=1)
163 | assert e1_token_start and e1_token_end and e2_token_start and e2_token_end
164 | inputs = {
165 | 'batch_inputs': inputs,
166 | 'batch_e1_idx': [[[e1_token_start, e1_token_end]]],
167 | 'batch_e2_idx': [[[e2_token_start, e2_token_end]]]
168 | }
169 | inputs = to_device(args, inputs)
170 | with torch.no_grad():
171 | outputs = model(**inputs)
172 | logits = outputs[1]
173 | pred = int(logits.argmax(dim=-1)[0].cpu().numpy())
174 | prob = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy().tolist()
175 | return pred, prob[pred]
176 |
177 | def get_event_sent(e_start, e_end, sents):
178 | for sent in sents:
179 | sent_end = sent.start + len(sent.text) - 1
180 | if e_start >= sent.start and e_end <= sent_end:
181 | return sent.text, e_start - sent.start, e_end - sent.start
182 | return None, None, None
183 |
184 | if __name__ == '__main__':
185 | args = parse_args()
186 | if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir):
187 | raise ValueError(f'Output directory ({args.output_dir}) already exists and is not empty.')
188 | if not os.path.exists(args.output_dir):
189 | os.mkdir(args.output_dir)
190 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
191 | args.n_gpu = torch.cuda.device_count()
192 | logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}')
193 | # Set seed
194 | seed_everything(args.seed)
195 | # Load pretrained model and tokenizer
196 | logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...')
197 | config = AutoConfig.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir)
198 | tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir)
199 | args.num_labels = 2
200 | model = MODEL_CLASSES[args.model_type].from_pretrained(
201 | args.model_checkpoint,
202 | config=config,
203 | cache_dir=args.cache_dir,
204 | args=args
205 | ).to(args.device)
206 | # Training
207 | if args.do_train:
208 | train_dataset = KBPCorefPair(args.train_file)
209 | dev_dataset = KBPCorefPair(args.dev_file)
210 | train(args, train_dataset, dev_dataset, model, tokenizer)
211 | # Testing
212 | save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')]
213 | if args.do_test:
214 | test_dataset = KBPCorefPair(args.test_file)
215 | test(args, test_dataset, model, tokenizer, save_weights)
216 | # Predicting
217 | if args.do_predict:
218 | kbp_sent_dic = defaultdict(list) # {filename: [Sentence]}
219 |         with open('../../data/kbp_sent.txt', 'rt', encoding='utf-8') as sents:
220 | for line in sents:
221 | doc_id, start, text = line.strip().split('\t')
222 | kbp_sent_dic[doc_id].append(Sentence(int(start), text))
223 |
224 | pred_event_file = 'epoch_3_dev_f1_57.9994_weights.bin_test_pred_events.json'
225 |
226 | for best_save_weight in save_weights:
227 | logger.info(f'loading weights from {best_save_weight}...')
228 | model.load_state_dict(torch.load(os.path.join(args.output_dir, best_save_weight)))
229 | logger.info(f'predicting coref labels of {best_save_weight}...')
230 |
231 | results = []
232 | model.eval()
233 | pred_event_filepath = os.path.join(args.output_dir, pred_event_file)
234 |             with open(pred_event_filepath, 'rt', encoding='utf-8') as f_in:
235 | for line in tqdm(f_in.readlines()):
236 | sample = json.loads(line.strip())
237 | events = [
238 | (event['start'], event['start'] + len(event['trigger']) - 1, event['trigger'])
239 | for event in sample['pred_label']
240 | ]
241 | sents = kbp_sent_dic[sample['doc_id']]
242 | new_events = []
243 | for e_start, e_end, e_trigger in events:
244 | e_sent, e_new_start, e_new_end = get_event_sent(e_start, e_end, sents)
245 | assert e_sent is not None and e_sent[e_new_start:e_new_end+1] == e_trigger
246 | new_events.append((e_new_start, e_new_end, e_sent))
247 | predictions, probabilities = [], []
248 | for i in range(len(new_events) - 1):
249 | for j in range(i + 1, len(new_events)):
250 | e1_char_start, e1_char_end, sent_1 = new_events[i]
251 | e2_char_start, e2_char_end, sent_2 = new_events[j]
252 | pred, prob = predict(args,
253 | sent_1, sent_2,
254 | e1_char_start, e1_char_end,
255 | e2_char_start, e2_char_end,
256 | model, tokenizer
257 | )
258 | predictions.append(pred)
259 | probabilities.append(prob)
260 | results.append({
261 | "doc_id": sample['doc_id'],
262 | "document": sample['document'],
263 | "events": [
264 | {
265 | 'start': char_start,
266 | 'end': char_end,
267 | 'trigger': trigger
268 | } for char_start, char_end, trigger in events
269 | ],
270 | "pred_label": predictions,
271 | "pred_prob": probabilities
272 | })
273 | with open(os.path.join(args.output_dir, best_save_weight + '_test_pred_corefs.json'), 'wt', encoding='utf-8') as f:
274 |                 for example_result in results:
275 |                     f.write(json.dumps(example_result) + '\n')
276 |
--------------------------------------------------------------------------------
/src/local_event_coref/run_local_base.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./bert_results/
2 |
3 | python3 run_local_base.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --model_type=bert \
6 | --model_checkpoint=../../../PT_MODELS/bert-large-cased/ \
7 | --train_file=../../data/train_filtered.json \
8 | --dev_file=../../data/dev_filtered.json \
9 | --test_file=../../data/test_filtered.json \
10 | --max_seq_length=512 \
11 | --learning_rate=1e-5 \
12 | --matching_style=multi \
13 | --softmax_loss=ce \
14 | --num_train_epochs=10 \
15 | --batch_size=4 \
16 | --do_train \
17 | --warmup_proportion=0. \
18 | --seed=42
--------------------------------------------------------------------------------
/src/local_event_coref/run_local_base_with_mask.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import torch
4 | import json
5 | from collections import namedtuple, defaultdict
6 | from tqdm.auto import tqdm
7 | from transformers import AutoConfig, AutoTokenizer
8 | from transformers import AdamW, get_scheduler
9 | from sklearn.metrics import classification_report
10 | import sys
11 | sys.path.append('../../')
12 | from src.tools import seed_everything, NpEncoder
13 | from src.local_event_coref.arg import parse_args
14 | from src.local_event_coref.data import KBPCorefPair, get_dataLoader, SUBTYPES
15 | from src.local_event_coref.modeling import BertForPairwiseECWithMask, RobertaForPairwiseECWithMask
16 |
17 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
18 | datefmt='%Y/%m/%d %H:%M:%S',
19 | level=logging.INFO)
20 | logger = logging.getLogger("Model")
21 | Sentence = namedtuple("Sentence", ["start", "text"])
22 |
23 | MODEL_CLASSES = {
24 | 'bert': BertForPairwiseECWithMask,
25 | 'spanbert': BertForPairwiseECWithMask,
26 | 'roberta': RobertaForPairwiseECWithMask
27 | }
28 |
29 | def to_device(args, batch_data):
30 | new_batch_data = {}
31 | for k, v in batch_data.items():
32 | if k in ['batch_inputs', 'batch_inputs_with_mask']:
33 | new_batch_data[k] = {
34 | k_: v_.to(args.device) for k_, v_ in v.items()
35 | }
36 | else:
37 | new_batch_data[k] = torch.tensor(v).to(args.device)
38 | return new_batch_data
39 |
40 | def train_loop(args, dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
41 | progress_bar = tqdm(range(len(dataloader)))
42 | progress_bar.set_description(f'loss: {0:>7f}')
43 | finish_step_num = epoch * len(dataloader)
44 |
45 | model.train()
46 | for step, batch_data in enumerate(dataloader, start=1):
47 | batch_data = to_device(args, batch_data)
48 | outputs = model(**batch_data)
49 | loss = outputs[0]
50 |
51 | optimizer.zero_grad()
52 | loss.backward()
53 | optimizer.step()
54 | lr_scheduler.step()
55 |
56 | total_loss += loss.item()
57 | progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}')
58 | progress_bar.update(1)
59 | return total_loss
60 |
61 | def test_loop(args, dataloader, model):
62 | true_labels, true_predictions = [], []
63 | model.eval()
64 | with torch.no_grad():
65 | for batch_data in tqdm(dataloader):
66 | batch_data = to_device(args, batch_data)
67 | outputs = model(**batch_data)
68 | logits = outputs[1]
69 |
70 | predictions = logits.argmax(dim=-1).cpu().numpy().tolist()
71 | labels = batch_data['labels'].cpu().numpy()
72 | true_predictions += predictions
73 | true_labels += [int(label) for label in labels]
74 | return classification_report(true_labels, true_predictions, output_dict=True)
75 |
76 | def train(args, train_dataset, dev_dataset, model, tokenizer):
77 | """ Train the model """
78 | train_dataloader = get_dataLoader(args, train_dataset, tokenizer, shuffle=True, collote_fn_type='with_mask')
79 | dev_dataloader = get_dataLoader(args, dev_dataset, tokenizer, shuffle=False, collote_fn_type='with_mask')
80 | t_total = len(train_dataloader) * args.num_train_epochs
81 | # Prepare optimizer and schedule (linear warmup and decay)
82 | no_decay = ["bias", "LayerNorm.weight"]
83 | optimizer_grouped_parameters = [
84 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
85 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
86 | ]
87 | args.warmup_steps = int(t_total * args.warmup_proportion)
88 | optimizer = AdamW(
89 | optimizer_grouped_parameters,
90 | lr=args.learning_rate,
91 | betas=(args.adam_beta1, args.adam_beta2),
92 | eps=args.adam_epsilon
93 | )
94 | lr_scheduler = get_scheduler(
95 | 'linear',
96 | optimizer,
97 | num_warmup_steps=args.warmup_steps,
98 | num_training_steps=t_total
99 | )
100 | # Train!
101 | logger.info("***** Running training *****")
102 | logger.info(f"Num examples - {len(train_dataset)}")
103 | logger.info(f"Num Epochs - {args.num_train_epochs}")
104 | logger.info(f"Total optimization steps - {t_total}")
105 | with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f:
106 | f.write(str(args))
107 |
108 | total_loss = 0.
109 | best_f1 = 0.
110 | for epoch in range(args.num_train_epochs):
111 | print(f"Epoch {epoch+1}/{args.num_train_epochs}\n-------------------------------")
112 | total_loss = train_loop(args, train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss)
113 | metrics = test_loop(args, dev_dataloader, model)
114 | dev_p, dev_r, dev_f1 = metrics['1']['precision'], metrics['1']['recall'], metrics['1']['f1-score']
115 | logger.info(f'Dev: P - {(100*dev_p):0.4f} R - {(100*dev_r):0.4f} F1 - {(100*dev_f1):0.4f}')
116 | if dev_f1 > best_f1:
117 | best_f1 = dev_f1
118 | logger.info(f'saving new weights to {args.output_dir}...\n')
119 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin'
120 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight))
121 | with open(os.path.join(args.output_dir, 'dev_metrics.txt'), 'at') as f:
122 | f.write(f'epoch_{epoch+1}\n' + json.dumps(metrics, cls=NpEncoder) + '\n\n')
123 | logger.info("Done!")
124 |
125 | def test(args, test_dataset, model, tokenizer, save_weights:list):
126 | test_dataloader = get_dataLoader(args, test_dataset, tokenizer, batch_size=1, shuffle=False, collote_fn_type='with_mask')
127 | logger.info('***** Running testing *****')
128 | for save_weight in save_weights:
129 | logger.info(f'loading weights from {save_weight}...')
130 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight)))
131 | metrics = test_loop(args, test_dataloader, model)
132 | with open(os.path.join(args.output_dir, 'test_metrics.txt'), 'at') as f:
133 | f.write(f'{save_weight}\n{json.dumps(metrics, cls=NpEncoder)}\n\n')
134 |
135 | def predict(args, sent_1, sent_2, e1_char_start, e1_char_end, e2_char_start, e2_char_end, model, tokenizer):
136 |
137 | def _cut_sent(sent, e_char_start, e_char_end, max_length):
138 | before = ' '.join([c for c in sent[:e_char_start].split(' ') if c != ''][-max_length:]).strip()
139 | trigger = sent[e_char_start:e_char_end+1]
140 | after = ' '.join([c for c in sent[e_char_end+1:].split(' ') if c != ''][:max_length]).strip()
141 | new_sent, new_char_start, new_char_end = before + ' ' + trigger + ' ' + after, len(before) + 1, len(before) + len(trigger)
142 | assert new_sent[new_char_start:new_char_end+1] == trigger
143 | return new_sent, new_char_start, new_char_end
144 |
145 | max_mention_length = (args.max_seq_length - 50) // 4
146 | sent_1, e1_char_start, e1_char_end = _cut_sent(sent_1, e1_char_start, e1_char_end, max_mention_length)
147 | sent_2, e2_char_start, e2_char_end = _cut_sent(sent_2, e2_char_start, e2_char_end, max_mention_length)
148 | inputs = tokenizer(
149 | sent_1,
150 | sent_2,
151 | max_length=args.max_seq_length,
152 | truncation=True,
153 | return_tensors="pt"
154 | )
155 |     inputs_with_mask = tokenizer(  # independent second encoding; its trigger tokens are overwritten with [MASK] below
156 | sent_1,
157 | sent_2,
158 | max_length=args.max_seq_length,
159 | truncation=True,
160 | return_tensors="pt"
161 | )
162 | e1_token_start = inputs.char_to_token(e1_char_start, sequence_index=0)
163 |     if not e1_token_start:  # char may map to no token (e.g. whitespace); retry one character later
164 | e1_token_start = inputs.char_to_token(e1_char_start + 1, sequence_index=0)
165 | e1_token_end = inputs.char_to_token(e1_char_end, sequence_index=0)
166 | e2_token_start = inputs.char_to_token(e2_char_start, sequence_index=1)
167 | if not e2_token_start:
168 | e2_token_start = inputs.char_to_token(e2_char_start + 1, sequence_index=1)
169 | e2_token_end = inputs.char_to_token(e2_char_end, sequence_index=1)
170 | assert e1_token_start and e1_token_end and e2_token_start and e2_token_end
171 | inputs_with_mask['input_ids'][0][e1_token_start:e1_token_end+1] = tokenizer.mask_token_id
172 | inputs_with_mask['input_ids'][0][e2_token_start:e2_token_end+1] = tokenizer.mask_token_id
173 | inputs = {
174 | 'batch_inputs': inputs,
175 | 'batch_inputs_with_mask': inputs_with_mask,
176 | 'batch_e1_idx': [[[e1_token_start, e1_token_end]]],
177 | 'batch_e2_idx': [[[e2_token_start, e2_token_end]]]
178 | }
179 | inputs = to_device(args, inputs)
180 | with torch.no_grad():
181 | outputs = model(**inputs)
182 | logits = outputs[1]
183 | pred = int(logits.argmax(dim=-1)[0].cpu().numpy())
184 | prob = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy().tolist()
185 | return pred, prob[pred]
186 |
187 | def get_event_sent(e_start, e_end, sents):
188 | for sent in sents:
189 | sent_end = sent.start + len(sent.text) - 1
190 | if e_start >= sent.start and e_end <= sent_end:
191 | return sent.text, e_start - sent.start, e_end - sent.start
192 | return None, None, None
193 |
194 | if __name__ == '__main__':
195 | args = parse_args()
196 | if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir):
197 | raise ValueError(f'Output directory ({args.output_dir}) already exists and is not empty.')
198 | if not os.path.exists(args.output_dir):
199 | os.mkdir(args.output_dir)
200 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
201 | args.n_gpu = torch.cuda.device_count()
202 | logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}')
203 | # Set seed
204 | seed_everything(args.seed)
205 | # Load pretrained model and tokenizer
206 | logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...')
207 | config = AutoConfig.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir)
208 | tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir)
209 | args.num_labels = 2
210 | args.num_subtypes = len(SUBTYPES) + 1
211 | model = MODEL_CLASSES[args.model_type].from_pretrained(
212 | args.model_checkpoint,
213 | config=config,
214 | cache_dir=args.cache_dir,
215 | args=args
216 | ).to(args.device)
217 | # Training
218 | if args.do_train:
219 | train_dataset = KBPCorefPair(args.train_file)
220 | dev_dataset = KBPCorefPair(args.dev_file)
221 | train(args, train_dataset, dev_dataset, model, tokenizer)
222 | # Testing
223 | save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')]
224 | if args.do_test:
225 | test_dataset = KBPCorefPair(args.test_file)
226 | test(args, test_dataset, model, tokenizer, save_weights)
227 | # Predicting
228 | if args.do_predict:
229 | kbp_sent_dic = defaultdict(list) # {filename: [Sentence]}
230 |         with open('../../data/kbp_sent.txt', 'rt', encoding='utf-8') as sents:
231 | for line in sents:
232 | doc_id, start, text = line.strip().split('\t')
233 | kbp_sent_dic[doc_id].append(Sentence(int(start), text))
234 |
235 |         pred_event_file = 'epoch_3_dev_f1_57.9994_weights.bin_test_pred_events.json'  # output of the trigger-detection stage (run_td_*.py)
236 |
237 | for best_save_weight in save_weights:
238 | logger.info(f'loading weights from {best_save_weight}...')
239 | model.load_state_dict(torch.load(os.path.join(args.output_dir, best_save_weight)))
240 | logger.info(f'predicting coref labels of {best_save_weight}...')
241 |
242 | results = []
243 | model.eval()
244 | pred_event_filepath = os.path.join(args.output_dir, pred_event_file)
245 | with open(pred_event_filepath, 'rt' , encoding='utf-8') as f_in:
246 | for line in tqdm(f_in.readlines()):
247 | sample = json.loads(line.strip())
248 | events = [
249 | (event['start'], event['start'] + len(event['trigger']) - 1, event['trigger'])
250 | for event in sample['pred_label']
251 | ]
252 | sents = kbp_sent_dic[sample['doc_id']]
253 | new_events = []
254 | for e_start, e_end, e_trigger in events:
255 | e_sent, e_new_start, e_new_end = get_event_sent(e_start, e_end, sents)
256 | assert e_sent is not None and e_sent[e_new_start:e_new_end+1] == e_trigger
257 | new_events.append((e_new_start, e_new_end, e_sent))
258 | predictions, probabilities = [], []
259 | for i in range(len(new_events) - 1):
260 | for j in range(i + 1, len(new_events)):
261 | e1_char_start, e1_char_end, sent_1 = new_events[i]
262 | e2_char_start, e2_char_end, sent_2 = new_events[j]
263 | pred, prob = predict(args,
264 | sent_1, sent_2,
265 | e1_char_start, e1_char_end,
266 | e2_char_start, e2_char_end,
267 | model, tokenizer
268 | )
269 | predictions.append(pred)
270 | probabilities.append(prob)
271 | results.append({
272 | "doc_id": sample['doc_id'],
273 | "document": sample['document'],
274 | "events": [
275 | {
276 | 'start': char_start,
277 | 'end': char_end,
278 | 'trigger': trigger
279 | } for char_start, char_end, trigger in events
280 | ],
281 | "pred_label": predictions,
282 | "pred_prob": probabilities
283 | })
284 | with open(os.path.join(args.output_dir, best_save_weight + '_test_pred_corefs.json'), 'wt', encoding='utf-8') as f:
285 |             for example_result in results:
286 |                 f.write(json.dumps(example_result) + '\n')
287 |
--------------------------------------------------------------------------------
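Note: the prediction stage above only emits one 0/1 coreference label per event-mention pair; grouping mentions into event chains is handled elsewhere (see src/clustering). As a minimal illustration, assuming `preds` follows the same i < j nested-loop order used above, the pairwise labels can be merged into clusters with union-find (all names here are illustrative, not part of the repository):

# Minimal union-find sketch: pairwise coreference predictions -> clusters.
# Assumes `preds` lists one 0/1 label per (i, j) pair with i < j, in the
# same nested-loop order used by the prediction stage above.
def cluster_events(num_events, preds):
    parent = list(range(num_events))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path compression
            x = parent[x]
        return x

    def union(x, y):
        parent[find(x)] = find(y)

    k = 0
    for i in range(num_events - 1):
        for j in range(i + 1, num_events):
            if preds[k] == 1:
                union(i, j)
            k += 1
    clusters = {}
    for e in range(num_events):
        clusters.setdefault(find(e), []).append(e)
    return list(clusters.values())

# e.g. 4 events, pairs (0,1),(0,2),(0,3),(1,2),(1,3),(2,3)
print(cluster_events(4, [1, 0, 0, 0, 0, 1]))  # [[0, 1], [2, 3]]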
/src/local_event_coref/run_local_base_with_mask.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./mask_bert_results/
2 |
3 | python3 run_local_base_with_mask.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --model_type=bert \
6 | --model_checkpoint=../../../PT_MODELS/bert-large-cased/ \
7 | --train_file=../../data/train_filtered.json \
8 | --dev_file=../../data/dev_filtered.json \
9 | --test_file=../../data/test_filtered.json \
10 | --max_seq_length=512 \
11 | --learning_rate=1e-5 \
12 | --matching_style=multi \
13 | --softmax_loss=ce \
14 | --num_train_epochs=10 \
15 | --batch_size=4 \
16 | --do_train \
17 | --warmup_proportion=0. \
18 | --seed=42
--------------------------------------------------------------------------------
/src/local_event_coref/run_local_base_with_mask_topic.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./mask_topic_bert_results/
2 |
3 | python3 run_local_base_with_mask_topic.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --model_type=bert \
6 | --model_checkpoint=../../../PT_MODELS/bert-large-cased/ \
7 | --topic_model=vmf \
8 | --topic_dim=32 \
9 | --topic_inter_map=64 \
10 | --train_file=../../data/train_filtered.json \
11 | --dev_file=../../data/dev_filtered.json \
12 | --test_file=../../data/test_filtered.json \
13 | --max_seq_length=512 \
14 | --learning_rate=1e-5 \
15 | --matching_style=multi \
16 | --softmax_loss=ce \
17 | --num_train_epochs=10 \
18 | --batch_size=4 \
19 | --do_train \
20 | --warmup_proportion=0. \
21 | --seed=42
--------------------------------------------------------------------------------
/src/local_event_coref/run_local_base_with_topic.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./topic_bert_results/
2 |
3 | python3 run_local_base_with_topic.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --model_type=bert \
6 | --model_checkpoint=../../../PT_MODELS/bert-large-cased/ \
7 | --topic_model=vmf \
8 | --topic_dim=32 \
9 | --topic_inter_map=64 \
10 | --train_file=../../data/train_filtered.json \
11 | --dev_file=../../data/dev_filtered.json \
12 | --test_file=../../data/test_filtered.json \
13 | --max_seq_length=512 \
14 | --learning_rate=1e-5 \
15 | --matching_style=multi \
16 | --softmax_loss=ce \
17 | --num_train_epochs=10 \
18 | --batch_size=4 \
19 | --do_train \
20 | --warmup_proportion=0. \
21 | --seed=42
--------------------------------------------------------------------------------
/src/trigger_detection/arg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_args():
4 | parser = argparse.ArgumentParser()
5 |
6 | # Required parameters
7 | parser.add_argument("--output_dir", default=None, type=str, required=True,
8 | help="The output directory where the model checkpoints and predictions will be written.",
9 | )
10 | parser.add_argument("--train_file", default=None, type=str, required=True, help="The input training file.")
11 | parser.add_argument("--dev_file", default=None, type=str, required=True, help="The input evaluation file.")
12 | parser.add_argument("--test_file", default=None, type=str, required=True, help="The input testing file.")
13 |
14 | parser.add_argument("--model_type",
15 | default="longformer", type=str, required=True
16 | )
17 | parser.add_argument("--model_checkpoint",
18 | default="allenai/longformer-base-4096", type=str, required=True,
19 | help="Path to pretrained model or model identifier from huggingface.co/models",
20 | )
21 | parser.add_argument("--max_seq_length", default=4096, type=int, required=True)
22 |
23 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
24 | parser.add_argument("--do_test", action="store_true", help="Whether to run eval on the test set.")
25 | parser.add_argument("--do_predict", action="store_true", help="Whether to save predicted labels.")
26 |
27 | # Other parameters
28 |     parser.add_argument("--cache_dir", default=None, type=str,
29 |         help="Directory in which to cache pre-trained models downloaded from huggingface.co."
30 |     )
31 |     parser.add_argument("--use_ffnn_layer", action="store_true", help="Whether to add an FFNN before the classifier.")
32 |     parser.add_argument("--ffnn_size", default=-1, type=int, help="The size of the FFNN layer (-1 falls back to the encoder hidden size).")
33 | parser.add_argument("--softmax_loss", default='ce', type=str,
34 | help="The loss function for softmax model.",
35 | choices=['lsr', 'focal', 'ce']
36 | )
37 | parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
38 | parser.add_argument("--crf_learning_rate", default=5e-5, type=float, help="The initial learning rate for crf.")
39 | parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.")
40 | parser.add_argument("--batch_size", default=4, type=int)
41 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
42 |
43 |     parser.add_argument("--adam_beta1", default=0.9, type=float,
44 |         help="Beta1 for the Adam optimizer."
45 |     )
46 |     parser.add_argument("--adam_beta2", default=0.98, type=float,
47 |         help="Beta2 for the Adam optimizer."
48 |     )
49 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
50 | help="Epsilon for Adam optimizer."
51 | )
52 |     parser.add_argument("--warmup_proportion", default=0.1, type=float,
53 |         help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training."  # '%%' escapes '%' in argparse help strings
54 |     )
55 |     parser.add_argument("--weight_decay", default=0.01, type=float,
56 |         help="Weight decay to apply, if any."
57 |     )
58 | args = parser.parse_args()
59 | return args
--------------------------------------------------------------------------------
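The parser above can be exercised without a shell script for a quick sanity check of defaults and required flags. A minimal sketch, run from this directory; the paths are placeholders that parse_args() itself never opens:

# Smoke-test the argument schema; all file paths here are placeholders.
import sys
from arg import parse_args

sys.argv = [
    'run_td_softmax.py',
    '--output_dir=./tmp_results/',
    '--model_type=longformer',
    '--model_checkpoint=allenai/longformer-base-4096',
    '--train_file=train.json', '--dev_file=dev.json', '--test_file=test.json',
    '--max_seq_length=4096',
]
args = parse_args()
assert args.softmax_loss == 'ce' and args.warmup_proportion == 0.1  # defaults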
/src/trigger_detection/data.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import json
3 | import numpy as np
4 | import torch
5 |
6 | CATEGORIES = [
7 |     'artifact', 'transferownership', 'transaction', 'broadcast', 'contact', 'demonstrate',
8 |     'injure', 'transfermoney', 'transportartifact', 'attack', 'meet', 'elect',
9 |     'endposition', 'correspondence', 'arrestjail', 'startposition', 'transportperson', 'die'
10 | ]
11 |
12 | class KBPTrigger(Dataset):
13 | def __init__(self, data_file):
14 | self.data = self.load_data(data_file)
15 |
16 | def load_data(self, data_file):
17 | Data = []
18 | with open(data_file, 'rt', encoding='utf-8') as f:
19 | for line in f:
20 | sample = json.loads(line.strip())
21 | tags = [
22 | (event['start'], event['start'] + len(event['trigger']) - 1, event['trigger'], event['subtype'])
23 | for event in sample['events'] if event['subtype'] in CATEGORIES
24 | ]
25 | Data.append({
26 | 'id': sample['doc_id'],
27 | 'document': sample['document'],
28 | 'tags': tags
29 | })
30 | return Data
31 |
32 | def __len__(self):
33 | return len(self.data)
34 |
35 | def __getitem__(self, idx):
36 | return self.data[idx]
37 |
38 | def get_dataLoader(args, dataset, tokenizer, batch_size=None, shuffle=False):
39 |
40 |     def collate_fn(batch_samples):
41 | batch_sentence, batch_tags = [], []
42 | for sample in batch_samples:
43 | batch_sentence.append(sample['document'])
44 | batch_tags.append(sample['tags'])
45 | batch_inputs = tokenizer(
46 | batch_sentence,
47 | max_length=args.max_seq_length,
48 | padding=True,
49 | truncation=True,
50 | return_tensors="pt"
51 | )
52 | batch_label = np.zeros(batch_inputs['input_ids'].shape, dtype=int)
53 | for s_idx, sentence in enumerate(batch_sentence):
54 | encoding = tokenizer(sentence, max_length=args.max_seq_length, truncation=True)
55 | for char_start, char_end, _, tag in batch_tags[s_idx]:
56 | token_start = encoding.char_to_token(char_start)
57 | token_end = encoding.char_to_token(char_end)
58 | if not token_start or not token_end:
59 | continue
60 | batch_label[s_idx][token_start] = args.label2id[f"B-{tag}"]
61 | batch_label[s_idx][token_start + 1:token_end + 1] = args.label2id[f"I-{tag}"]
62 | batch_inputs['labels'] = torch.tensor(batch_label)
63 | return batch_inputs
64 |
65 | return DataLoader(dataset, batch_size=(batch_size if batch_size else args.batch_size), shuffle=shuffle,
66 |                       collate_fn=collate_fn)
--------------------------------------------------------------------------------
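The collate function above maps character-level trigger spans onto token-level BIO labels via the fast tokenizer's char_to_token, skipping spans that were truncated away. A minimal standalone sketch of that alignment (the checkpoint name is illustrative; any fast tokenizer works):

# Sketch: char-span -> token-span BIO alignment with a fast tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')  # illustrative
text = 'Troops attacked the village at dawn.'
char_start, char_end, subtype = 7, 14, 'attack'  # span of 'attacked', inclusive end

encoding = tokenizer(text)
token_start = encoding.char_to_token(char_start)
token_end = encoding.char_to_token(char_end)
labels = ['O'] * len(encoding['input_ids'])
labels[token_start] = f'B-{subtype}'           # first sub-token gets B-
for t in range(token_start + 1, token_end + 1):
    labels[t] = f'I-{subtype}'                 # remaining sub-tokens get I-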
/src/trigger_detection/modeling.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn import CrossEntropyLoss
3 | from transformers import LongformerPreTrainedModel, LongformerModel
4 | from ..tools import LabelSmoothingCrossEntropy, FocalLoss, CRF
5 | from ..tools import FullyConnectedLayer
6 |
7 | class LongformerSoftmaxForTD(LongformerPreTrainedModel):
8 | def __init__(self, config, args):
9 | super().__init__(config)
10 | self.num_labels = args.num_labels
11 | self.longformer = LongformerModel(config, add_pooling_layer=False)
12 | self.use_ffnn_layer = args.use_ffnn_layer
13 | if self.use_ffnn_layer:
14 | self.ffnn_size = args.ffnn_size if args.ffnn_size != -1 else config.hidden_size
15 | self.mlp = FullyConnectedLayer(config, config.hidden_size, self.ffnn_size, config.hidden_dropout_prob)
16 | else:
17 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
18 | self.classifier = nn.Linear(self.ffnn_size if args.use_ffnn_layer else config.hidden_size, self.num_labels)
19 | self.loss_type = args.softmax_loss
20 | self.post_init()
21 |
22 | def forward(self, input_ids, attention_mask, labels=None):
23 | outputs = self.longformer(input_ids, attention_mask=attention_mask)
24 | sequence_output = outputs[0]
25 | if self.use_ffnn_layer:
26 | sequence_output = self.mlp(sequence_output)
27 | else:
28 | sequence_output = self.dropout(sequence_output)
29 | logits = self.classifier(sequence_output)
30 |
31 | loss = None
32 | if labels is not None:
33 | assert self.loss_type in ['lsr', 'focal', 'ce']
34 | if self.loss_type == 'lsr':
35 | loss_fct = LabelSmoothingCrossEntropy()
36 | elif self.loss_type == 'focal':
37 | loss_fct = FocalLoss()
38 | else:
39 | loss_fct = CrossEntropyLoss()
40 | # Only keep active parts of the loss
41 | if attention_mask is not None:
42 | active_loss = attention_mask.view(-1) == 1
43 | active_logits = logits.view(-1, self.num_labels)[active_loss]
44 | active_labels = labels.view(-1)[active_loss]
45 | loss = loss_fct(active_logits, active_labels)
46 | else:
47 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
48 | return loss, logits
49 |
50 | class LongformerCrfForTD(LongformerPreTrainedModel):
51 | def __init__(self, config, args):
52 | super().__init__(config)
53 | self.num_labels = args.num_labels
54 | self.longformer = LongformerModel(config, add_pooling_layer=False)
55 | self.use_ffnn_layer = args.use_ffnn_layer
56 | if self.use_ffnn_layer:
57 | self.ffnn_size = args.ffnn_size if args.ffnn_size != -1 else config.hidden_size
58 | self.mlp = FullyConnectedLayer(config, config.hidden_size, self.ffnn_size, config.hidden_dropout_prob)
59 | else:
60 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
61 | self.classifier = nn.Linear(self.ffnn_size if args.use_ffnn_layer else config.hidden_size, self.num_labels)
62 | self.crf = CRF(num_tags=self.num_labels, batch_first=True)
63 | self.post_init()
64 |
65 | def forward(self, input_ids, attention_mask, labels=None):
66 | outputs = self.longformer(input_ids, attention_mask=attention_mask)
67 | sequence_output = outputs[0]
68 | if self.use_ffnn_layer:
69 | sequence_output = self.mlp(sequence_output)
70 | else:
71 | sequence_output = self.dropout(sequence_output)
72 | logits = self.classifier(sequence_output)
73 |
74 | loss = None
75 | if labels is not None:
76 | loss = -1 * self.crf(emissions=logits, tags=labels, mask=attention_mask)
77 | return loss, logits
78 |
--------------------------------------------------------------------------------
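LongformerSoftmaxForTD above computes the token-classification loss only over non-padding positions by flattening logits and labels and boolean-indexing with the attention mask. A self-contained sketch of that "active loss" trick on random tensors:

# "Active loss" masking: token-classification CE over real tokens only.
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, num_labels = 2, 5, 3
logits = torch.randn(batch, seq_len, num_labels)
labels = torch.randint(0, num_labels, (batch, seq_len))
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

active = attention_mask.view(-1) == 1              # keep non-padding positions
active_logits = logits.view(-1, num_labels)[active]
active_labels = labels.view(-1)[active]
loss = CrossEntropyLoss()(active_logits, active_labels)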
/src/trigger_detection/run_td_crf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | from tqdm.auto import tqdm
5 | import numpy as np
6 | import torch
7 | from transformers import AutoConfig, AutoTokenizer
8 | from transformers import AdamW, get_scheduler
9 | from seqeval.metrics import classification_report
10 | from seqeval.scheme import IOB2
11 | import sys
12 | sys.path.append('../../')
13 | from src.trigger_detection.data import KBPTrigger, get_dataLoader, CATEGORIES
14 | from src.trigger_detection.modeling import LongformerCrfForTD
15 | from src.trigger_detection.arg import parse_args
16 | from src.tools import seed_everything, NpEncoder
17 |
18 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
19 | datefmt='%Y/%m/%d %H:%M:%S',
20 | level=logging.INFO)
21 | logger = logging.getLogger("Model")
22 |
23 | def train_loop(args, dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
24 | progress_bar = tqdm(range(len(dataloader)))
25 | progress_bar.set_description(f'loss: {0:>7f}')
26 | finish_step_num = epoch * len(dataloader)
27 |
28 | model.train()
29 | for step, batch_data in enumerate(dataloader, start=1):
30 | batch_data = batch_data.to(args.device)
31 | outputs = model(**batch_data)
32 | loss = outputs[0]
33 |
34 | optimizer.zero_grad()
35 | loss.backward()
36 | optimizer.step()
37 | lr_scheduler.step()
38 |
39 | total_loss += loss.item()
40 | progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}')
41 | progress_bar.update(1)
42 | return total_loss
43 |
44 | def test_loop(args, dataloader, model):
45 | true_labels, true_predictions = [], []
46 | model.eval()
47 | with torch.no_grad():
48 | for batch_data in tqdm(dataloader):
49 | batch_data = batch_data.to(args.device)
50 | outputs = model(**batch_data)
51 | logits = outputs[1]
52 | tags = model.crf.decode(logits, batch_data['attention_mask'])
53 | predictions = tags.squeeze(0).cpu().numpy()
54 | labels = batch_data['labels'].cpu().numpy()
55 | lens = np.sum(batch_data['attention_mask'].cpu().numpy(), axis=-1)
56 | true_labels += [
57 | [args.id2label[int(l)] for idx, l in enumerate(label) if idx > 0 and idx < seq_len - 1]
58 | for label, seq_len in zip(labels, lens)
59 | ]
60 | true_predictions += [
61 | [args.id2label[int(p)] for idx, p in enumerate(prediction) if idx > 0 and idx < seq_len - 1]
62 | for prediction, seq_len in zip(predictions, lens)
63 | ]
64 | return classification_report(true_labels, true_predictions, mode='strict', scheme=IOB2, output_dict=True)
65 |
66 | def train(args, train_dataset, dev_dataset, model, tokenizer):
67 | """ Train the model """
68 | train_dataloader = get_dataLoader(args, train_dataset, tokenizer, shuffle=True)
69 | dev_dataloader = get_dataLoader(args, dev_dataset, tokenizer, shuffle=False)
70 | t_total = len(train_dataloader) * args.num_train_epochs
71 | # Prepare optimizer and schedule (linear warmup and decay)
72 | no_decay = ["bias", "LayerNorm.weight"]
73 | longformer_param_optimizer = list(model.longformer.named_parameters())
74 | crf_param_optimizer = list(model.crf.named_parameters())
75 | linear_param_optimizer = list(model.classifier.named_parameters())
76 | optimizer_grouped_parameters = [
77 | {'params': [p for n, p in longformer_param_optimizer if not any(nd in n for nd in no_decay)],
78 | 'weight_decay': args.weight_decay, 'lr': args.learning_rate},
79 | {'params': [p for n, p in longformer_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0,
80 | 'lr': args.learning_rate},
81 |
82 | {'params': [p for n, p in crf_param_optimizer if not any(nd in n for nd in no_decay)],
83 | 'weight_decay': args.weight_decay, 'lr': args.crf_learning_rate},
84 | {'params': [p for n, p in crf_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0,
85 | 'lr': args.crf_learning_rate},
86 |
87 | {'params': [p for n, p in linear_param_optimizer if not any(nd in n for nd in no_decay)],
88 | 'weight_decay': args.weight_decay, 'lr': args.crf_learning_rate},
89 | {'params': [p for n, p in linear_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0,
90 | 'lr': args.crf_learning_rate}
91 | ]
92 | args.warmup_steps = int(t_total * args.warmup_proportion)
93 | optimizer = AdamW(
94 | optimizer_grouped_parameters,
95 | lr=args.learning_rate,
96 | betas=(args.adam_beta1, args.adam_beta2),
97 | eps=args.adam_epsilon
98 | )
99 | lr_scheduler = get_scheduler(
100 | 'linear',
101 | optimizer,
102 | num_warmup_steps=args.warmup_steps,
103 | num_training_steps=t_total
104 | )
105 | # Train!
106 | logger.info("***** Running training *****")
107 | logger.info(f"Num examples - {len(train_dataset)}")
108 | logger.info(f"Num Epochs - {args.num_train_epochs}")
109 | logger.info(f"Total optimization steps - {t_total}")
110 | with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f:
111 | f.write(str(args))
112 |
113 | total_loss = 0.
114 | best_f1 = 0.
115 | for epoch in range(args.num_train_epochs):
116 | print(f"Epoch {epoch+1}/{args.num_train_epochs}\n-------------------------------")
117 | total_loss = train_loop(args, train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss)
118 | metrics = test_loop(args, dev_dataloader, model)
119 | micro_f1, macro_f1 = metrics['micro avg']['f1-score'], metrics['macro avg']['f1-score']
120 | dev_f1 = metrics['weighted avg']['f1-score']
121 | logger.info(f'Dev: micro_F1 - {(100*micro_f1):0.4f} macro_f1 - {(100*macro_f1):0.4f} weighted_f1 - {(100*dev_f1):0.4f}')
122 | if dev_f1 > best_f1:
123 | best_f1 = dev_f1
124 | logger.info(f'saving new weights to {args.output_dir}...\n')
125 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin'
126 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight))
127 | with open(os.path.join(args.output_dir, 'dev_metrics.txt'), 'at') as f:
128 | f.write(f'epoch_{epoch+1}\n' + json.dumps(metrics, cls=NpEncoder) + '\n\n')
129 | logger.info("Done!")
130 |
131 | def predict(args, document:str, model, tokenizer):
132 | inputs = tokenizer(
133 | document,
134 | max_length=args.max_seq_length,
135 | truncation=True,
136 | return_tensors="pt",
137 | return_offsets_mapping=True
138 | )
139 | offsets = inputs.pop('offset_mapping').squeeze(0)
140 | inputs = inputs.to(args.device)
141 | with torch.no_grad():
142 | outputs = model(**inputs)
143 | logits = outputs[1]
144 | probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy().tolist()
145 | predictions = model.crf.decode(logits, inputs['attention_mask'])
146 | predictions = predictions.squeeze(0)[0].cpu().numpy().tolist()
147 |
148 | pred_label = []
149 | idx = 1
150 | while idx < len(predictions) - 1:
151 | pred = predictions[idx]
152 | label = args.id2label[pred]
153 | if label != "O":
154 | label = label[2:] # Remove the B- or I-
155 | start, end = offsets[idx]
156 | all_scores = [probabilities[idx][pred]]
157 | # Grab all the tokens labeled with I-label
158 | while (
159 | idx + 1 < len(predictions) - 1 and
160 | args.id2label[predictions[idx + 1]] == f"I-{label}"
161 | ):
162 | all_scores.append(probabilities[idx + 1][predictions[idx + 1]])
163 | _, end = offsets[idx + 1]
164 | idx += 1
165 |
166 | score = np.mean(all_scores).item()
167 | start, end = start.item(), end.item()
168 | word = document[start:end]
169 | pred_label.append({
170 | "trigger": word,
171 | "start": start,
172 | "subtype": label,
173 | "score": score
174 | })
175 | idx += 1
176 | return pred_label
177 |
178 | def test(args, test_dataset, model, tokenizer, save_weights:list):
179 | test_dataloader = get_dataLoader(args, test_dataset, tokenizer, batch_size=1, shuffle=False)
180 | logger.info('***** Running testing *****')
181 | for save_weight in save_weights:
182 | logger.info(f'loading weights from {save_weight}...')
183 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight)))
184 | metrics = test_loop(args, test_dataloader, model)
185 | with open(os.path.join(args.output_dir, 'test_metrics.txt'), 'at') as f:
186 | f.write(f'{save_weight}\n{json.dumps(metrics, cls=NpEncoder)}\n\n')
187 | if args.do_predict:
188 | logger.info(f'predicting labels of {save_weight}...')
189 | results = []
190 | model.eval()
191 | for sample in tqdm(test_dataset):
192 | pred_label = predict(args, sample['document'], model, tokenizer)
193 | results.append({
194 | "doc_id": sample['id'],
195 | "document": sample['document'],
196 | "pred_label": pred_label,
197 | "true_label": sample['tags']
198 | })
199 | with open(os.path.join(args.output_dir, save_weight + '_test_pred_events.json'), 'wt', encoding='utf-8') as f:
200 |                 for example_result in results:
201 |                     f.write(json.dumps(example_result) + '\n')
202 |
203 | if __name__ == '__main__':
204 | args = parse_args()
205 | if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir):
206 | raise ValueError(
207 | f'Output directory ({args.output_dir}) already exists and is not empty.')
208 | if not os.path.exists(args.output_dir):
209 | os.mkdir(args.output_dir)
210 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
211 | args.n_gpu = torch.cuda.device_count()
212 | logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}')
213 | # Set seed
214 | seed_everything(args.seed)
215 | # Prepare task
216 | args.id2label = {0:'O'}
217 | for c in CATEGORIES:
218 | args.id2label[len(args.id2label)] = f"B-{c}"
219 | args.id2label[len(args.id2label)] = f"I-{c}"
220 | args.label2id = {v: k for k, v in args.id2label.items()}
221 | args.num_labels = len(args.id2label)
222 | # Load pretrained model and tokenizer
223 | logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...')
224 | config = AutoConfig.from_pretrained(
225 | args.model_checkpoint,
226 | cache_dir=args.cache_dir
227 | )
228 | tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir)
229 | model = LongformerCrfForTD.from_pretrained(
230 | args.model_checkpoint,
231 | config=config,
232 | cache_dir=args.cache_dir,
233 | args=args
234 | ).to(args.device)
235 | # Training
236 | if args.do_train:
237 | logger.info(f'Training/evaluation parameters: {args}')
238 | train_dataset = KBPTrigger(args.train_file)
239 | dev_dataset = KBPTrigger(args.dev_file)
240 | train(args, train_dataset, dev_dataset, model, tokenizer)
241 | # Testing
242 | save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')]
243 |     test_dataset = KBPTrigger(args.test_file)  # built unconditionally: --do_predict below also needs it
244 |     if args.do_test:
245 |         test(args, test_dataset, model, tokenizer, save_weights)
246 | # Predicting
247 | if args.do_predict:
248 | for save_weight in save_weights:
249 | logger.info(f'loading weights from {save_weight}...')
250 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight)))
251 | logger.info(f'predicting labels of {save_weight}...')
252 |
253 | results = []
254 | model.eval()
255 | for sample in tqdm(test_dataset):
256 | pred_label = predict(args, sample['document'], model, tokenizer)
257 | results.append({
258 | "doc_id": sample['id'],
259 | "document": sample['document'],
260 | "pred_label": pred_label,
261 | "true_label": sample['tags']
262 | })
263 | with open(os.path.join(args.output_dir, save_weight + '_test_pred_events.json'), 'wt', encoding='utf-8') as f:
264 |             for example_result in results:
265 |                 f.write(json.dumps(example_result) + '\n')
--------------------------------------------------------------------------------
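test_loop above drops the special tokens at both sequence ends (the 0 < idx < seq_len - 1 filter) and then scores with seqeval under the strict IOB2 scheme, where a truncated span counts as a complete miss. A toy example of the same scoring call:

# seqeval strict IOB2 scoring on toy sequences, mirroring test_loop.
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2

true = [['O', 'B-attack', 'I-attack', 'O']]
pred = [['O', 'B-attack', 'O', 'O']]  # truncated span -> no credit under strict matching
report = classification_report(true, pred, mode='strict', scheme=IOB2, output_dict=True)
print(report['micro avg']['f1-score'])  # 0.0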
/src/trigger_detection/run_td_crf.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./crf_results/
2 |
3 | python3 run_td_crf.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --model_type=longformer \
6 | --model_checkpoint=../../PT_MODELS/allenai/longformer-large-4096/ \
7 | --train_file=../../data/train_filtered.json \
8 | --dev_file=../../data/dev_filtered.json \
9 | --test_file=../../data/test_filtered.json \
10 | --max_seq_length=4096 \
11 | --learning_rate=1e-5 \
12 | --crf_learning_rate=5e-5 \
13 | --num_train_epochs=50 \
14 | --batch_size=1 \
15 | --do_train \
16 | --warmup_proportion=0. \
17 | --seed=42
--------------------------------------------------------------------------------
/src/trigger_detection/run_td_softmax.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | from tqdm.auto import tqdm
5 | import numpy as np
6 | import torch
7 | from transformers import AutoConfig, AutoTokenizer
8 | from transformers import AdamW, get_scheduler
9 | from seqeval.metrics import classification_report
10 | from seqeval.scheme import IOB2
11 | import sys
12 | sys.path.append('../../')
13 | from src.trigger_detection.data import KBPTrigger, get_dataLoader, CATEGORIES
14 | from src.trigger_detection.modeling import LongformerSoftmaxForTD
15 | from src.trigger_detection.arg import parse_args
16 | from src.tools import seed_everything, NpEncoder
17 |
18 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
19 | datefmt='%Y/%m/%d %H:%M:%S',
20 | level=logging.INFO)
21 | logger = logging.getLogger("Model")
22 |
23 | def train_loop(args, dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
24 | progress_bar = tqdm(range(len(dataloader)))
25 | progress_bar.set_description(f'loss: {0:>7f}')
26 | finish_step_num = epoch * len(dataloader)
27 |
28 | model.train()
29 | for step, batch_data in enumerate(dataloader, start=1):
30 | batch_data = batch_data.to(args.device)
31 | outputs = model(**batch_data)
32 | loss = outputs[0]
33 |
34 | optimizer.zero_grad()
35 | loss.backward()
36 | optimizer.step()
37 | lr_scheduler.step()
38 |
39 | total_loss += loss.item()
40 | progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}')
41 | progress_bar.update(1)
42 | return total_loss
43 |
44 | def test_loop(args, dataloader, model):
45 | true_labels, true_predictions = [], []
46 | model.eval()
47 | with torch.no_grad():
48 | for batch_data in tqdm(dataloader):
49 | batch_data = batch_data.to(args.device)
50 | outputs = model(**batch_data)
51 | logits = outputs[1]
52 | predictions = logits.argmax(dim=-1).cpu().numpy() # [batch, seq]
53 | labels = batch_data['labels'].cpu().numpy()
54 | lens = np.sum(batch_data['attention_mask'].cpu().numpy(), axis=-1)
55 | true_labels += [
56 | [args.id2label[int(l)] for idx, l in enumerate(label) if idx > 0 and idx < seq_len - 1]
57 | for label, seq_len in zip(labels, lens)
58 | ]
59 | true_predictions += [
60 | [args.id2label[int(p)] for idx, p in enumerate(prediction) if idx > 0 and idx < seq_len - 1]
61 | for prediction, seq_len in zip(predictions, lens)
62 | ]
63 | return classification_report(true_labels, true_predictions, mode='strict', scheme=IOB2, output_dict=True)
64 |
65 | def train(args, train_dataset, dev_dataset, model, tokenizer):
66 | """ Train the model """
67 | train_dataloader = get_dataLoader(args, train_dataset, tokenizer, shuffle=True)
68 | dev_dataloader = get_dataLoader(args, dev_dataset, tokenizer, shuffle=False)
69 | t_total = len(train_dataloader) * args.num_train_epochs
70 | # Prepare optimizer and schedule (linear warmup and decay)
71 | no_decay = ["bias", "LayerNorm.weight"]
72 | optimizer_grouped_parameters = [
73 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
74 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
75 | ]
76 | args.warmup_steps = int(t_total * args.warmup_proportion)
77 | optimizer = AdamW(
78 | optimizer_grouped_parameters,
79 | lr=args.learning_rate,
80 | betas=(args.adam_beta1, args.adam_beta2),
81 | eps=args.adam_epsilon
82 | )
83 | lr_scheduler = get_scheduler(
84 | 'linear',
85 | optimizer,
86 | num_warmup_steps=args.warmup_steps,
87 | num_training_steps=t_total
88 | )
89 | # Train!
90 | logger.info("***** Running training *****")
91 | logger.info(f"Num examples - {len(train_dataset)}")
92 | logger.info(f"Num Epochs - {args.num_train_epochs}")
93 | logger.info(f"Total optimization steps - {t_total}")
94 | with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f:
95 | f.write(str(args))
96 |
97 | total_loss = 0.
98 | best_f1 = 0.
99 | for epoch in range(args.num_train_epochs):
100 | print(f"Epoch {epoch+1}/{args.num_train_epochs}\n-------------------------------")
101 | total_loss = train_loop(args, train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss)
102 | metrics = test_loop(args, dev_dataloader, model)
103 | micro_f1, macro_f1 = metrics['micro avg']['f1-score'], metrics['macro avg']['f1-score']
104 | dev_f1 = metrics['weighted avg']['f1-score']
105 | logger.info(f'Dev: micro_F1 - {(100*micro_f1):0.4f} macro_f1 - {(100*macro_f1):0.4f} weighted_f1 - {(100*dev_f1):0.4f}')
106 | if dev_f1 > best_f1:
107 | best_f1 = dev_f1
108 | logger.info(f'saving new weights to {args.output_dir}...\n')
109 | save_weight = f'epoch_{epoch+1}_dev_f1_{(100*dev_f1):0.4f}_weights.bin'
110 | torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight))
111 | with open(os.path.join(args.output_dir, 'dev_metrics.txt'), 'at') as f:
112 | f.write(f'epoch_{epoch+1}\n' + json.dumps(metrics, cls=NpEncoder) + '\n\n')
113 | logger.info("Done!")
114 |
115 | def predict(args, document:str, model, tokenizer):
116 | inputs = tokenizer(
117 | document,
118 | max_length=args.max_seq_length,
119 | truncation=True,
120 | return_tensors="pt",
121 | return_offsets_mapping=True
122 | )
123 | offsets = inputs.pop('offset_mapping').squeeze(0)
124 | inputs = inputs.to(args.device)
125 | with torch.no_grad():
126 | outputs = model(**inputs)
127 | logits = outputs[1]
128 | probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy().tolist()
129 | predictions = logits.argmax(dim=-1)[0].cpu().numpy().tolist()
130 |
131 | pred_label = []
132 | idx = 1
133 | while idx < len(predictions) - 1:
134 | pred = predictions[idx]
135 | label = args.id2label[pred]
136 | if label != "O":
137 | label = label[2:] # Remove the B- or I-
138 | start, end = offsets[idx]
139 | all_scores = [probabilities[idx][pred]]
140 | # Grab all the tokens labeled with I-label
141 | while (
142 | idx + 1 < len(predictions) - 1 and
143 | args.id2label[predictions[idx + 1]] == f"I-{label}"
144 | ):
145 | all_scores.append(probabilities[idx + 1][predictions[idx + 1]])
146 | _, end = offsets[idx + 1]
147 | idx += 1
148 |
149 | score = np.mean(all_scores).item()
150 | start, end = start.item(), end.item()
151 | word = document[start:end]
152 | pred_label.append({
153 | "trigger": word,
154 | "start": start,
155 | "subtype": label,
156 | "score": score
157 | })
158 | idx += 1
159 | return pred_label
160 |
161 | def test(args, test_dataset, model, tokenizer, save_weights:list):
162 | test_dataloader = get_dataLoader(args, test_dataset, tokenizer, batch_size=1, shuffle=False)
163 | logger.info('***** Running testing *****')
164 | for save_weight in save_weights:
165 | logger.info(f'loading weights from {save_weight}...')
166 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight)))
167 | metrics = test_loop(args, test_dataloader, model)
168 | with open(os.path.join(args.output_dir, 'test_metrics.txt'), 'at') as f:
169 | f.write(f'{save_weight}\n{json.dumps(metrics, cls=NpEncoder)}\n\n')
170 |
171 | if __name__ == '__main__':
172 | args = parse_args()
173 | if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir):
174 | raise ValueError(
175 | f'Output directory ({args.output_dir}) already exists and is not empty.')
176 | if not os.path.exists(args.output_dir):
177 | os.mkdir(args.output_dir)
178 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
179 | args.n_gpu = torch.cuda.device_count()
180 | logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}')
181 | # Set seed
182 | seed_everything(args.seed)
183 | # Prepare task
184 | args.id2label = {0:'O'}
185 | for c in CATEGORIES:
186 | args.id2label[len(args.id2label)] = f"B-{c}"
187 | args.id2label[len(args.id2label)] = f"I-{c}"
188 | args.label2id = {v: k for k, v in args.id2label.items()}
189 | args.num_labels = len(args.id2label)
190 | # Load pretrained model and tokenizer
191 | logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...')
192 | config = AutoConfig.from_pretrained(
193 | args.model_checkpoint,
194 | cache_dir=args.cache_dir
195 | )
196 | tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, cache_dir=args.cache_dir)
197 | model = LongformerSoftmaxForTD.from_pretrained(
198 | args.model_checkpoint,
199 | config=config,
200 | cache_dir=args.cache_dir,
201 | args=args
202 | ).to(args.device)
203 | # Training
204 | if args.do_train:
205 | train_dataset = KBPTrigger(args.train_file)
206 | dev_dataset = KBPTrigger(args.dev_file)
207 | train(args, train_dataset, dev_dataset, model, tokenizer)
208 | # Testing
209 | save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')]
210 |     test_dataset = KBPTrigger(args.test_file)  # built unconditionally: --do_predict below also needs it
211 |     if args.do_test:
212 |         test(args, test_dataset, model, tokenizer, save_weights)
213 | # Predicting
214 | if args.do_predict:
215 | for save_weight in save_weights:
216 | logger.info(f'loading weights from {save_weight}...')
217 | model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight)))
218 | logger.info(f'predicting labels of {save_weight}...')
219 |
220 | results = []
221 | model.eval()
222 | for sample in tqdm(test_dataset):
223 | pred_label = predict(args, sample['document'], model, tokenizer)
224 | results.append({
225 | "doc_id": sample['id'],
226 | "document": sample['document'],
227 | "pred_label": pred_label,
228 | "true_label": sample['tags']
229 | })
230 | with open(os.path.join(args.output_dir, save_weight + '_test_pred_events.json'), 'wt', encoding='utf-8') as f:
231 |             for example_result in results:
232 |                 f.write(json.dumps(example_result) + '\n')
233 |
--------------------------------------------------------------------------------
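The predict function above merges each B- token with its trailing I- tokens into a single trigger span and averages their probabilities. The span-merging step in isolation, over already-decoded label strings (function name illustrative):

# Merge BIO token labels into (start_token, end_token, subtype) spans,
# mirroring the B-/I- aggregation loop in predict() above.
def bio_to_spans(labels):
    spans, i = [], 0
    while i < len(labels):
        if labels[i].startswith('B-'):
            subtype, start = labels[i][2:], i
            while i + 1 < len(labels) and labels[i + 1] == f'I-{subtype}':
                i += 1  # absorb continuation tokens
            spans.append((start, i, subtype))
        i += 1
    return spans

print(bio_to_spans(['O', 'B-attack', 'I-attack', 'O', 'B-die']))
# [(1, 2, 'attack'), (4, 4, 'die')]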
/src/trigger_detection/run_td_softmax.sh:
--------------------------------------------------------------------------------
1 | export OUTPUT_DIR=./softmax_ce_results/
2 |
3 | python3 run_td_softmax.py \
4 | --output_dir=$OUTPUT_DIR \
5 | --model_type=longformer \
6 | --model_checkpoint=../../PT_MODELS/allenai/longformer-large-4096/ \
7 | --train_file=../../data/train_filtered.json \
8 | --dev_file=../../data/dev_filtered.json \
9 | --test_file=../../data/test_filtered.json \
10 | --max_seq_length=4096 \
11 | --learning_rate=1e-5 \
12 | --softmax_loss=ce \
13 | --num_train_epochs=50 \
14 | --batch_size=1 \
15 | --do_train \
16 | --warmup_proportion=0. \
17 | --seed=42
--------------------------------------------------------------------------------