├── src
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── roberta.py
│   │   ├── albert.py
│   │   ├── distilbert.py
│   │   ├── bert.py
│   │   └── xlnet.py
│   ├── fincausal_evaluation
│   │   ├── __init__.py
│   │   └── task2_evaluate.py
│   ├── logging.py
│   ├── config.py
│   ├── data.py
│   ├── training.py
│   ├── preprocessing.py
│   └── evaluation.py
├── utils
│   ├── __init__.py
│   └── split_dataset.py
├── output
│   └── README.md
├── data
│   └── README.md
├── .gitignore
├── README.md
├── requirements.txt
├── main.py
└── LICENSE

/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/fincausal_evaluation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/output/README.md:
--------------------------------------------------------------------------------
1 | The model outputs are stored in this directory.
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | Place the data in this directory:
2 | - Trial dataset (`fnp2020-fincausal-task2.csv`)
3 | - Practice dataset (`fnp2020-fincausal2-task2.csv`)
4 | - Test dataset (`task2.csv`)
--------------------------------------------------------------------------------
/utils/split_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import pandas as pd
3 | from pathlib import Path
4 | import os
5 | import sys
6 | from sklearn.model_selection import train_test_split
7 |
8 | if __name__ == '__main__':
9 |
10 |     fincausal_data_path = Path(os.environ.get('FINCAUSAL_DATA_PATH',
11 |                                               os.path.dirname(os.path.realpath(sys.argv[0])) + '/../data'))
12 |
13 |     input_file = fincausal_data_path / "fnp2020-fincausal-task2.csv"
14 |     train_output = fincausal_data_path / "fnp2020-train.csv"
15 |     dev_output = fincausal_data_path / "fnp2020-eval.csv"
16 |     size = 0.1
17 |     seed = 42
18 |
19 |     data = pd.read_csv(input_file, delimiter=';', header=0)
20 |
21 |     train, test = train_test_split(data, test_size=size, random_state=seed)
22 |
23 |     train.to_csv(train_output, header=True, sep=';', index=False)
24 |     test.to_csv(dev_output, header=True, sep=';', index=False)
25 |
--------------------------------------------------------------------------------
/src/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020, Guillaume Becquin
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Dict 16 | 17 | from src.config import ModelConfigurations, RunConfig 18 | 19 | 20 | def initialize_log_dict(model_config: ModelConfigurations, 21 | run_config: RunConfig, 22 | model_tokenizer_mapping: Dict) -> Dict: 23 | model_type = model_config.value[0] 24 | model_class, tokenizer_class = model_tokenizer_mapping[model_type] 25 | 26 | return {'MODEL_TYPE': model_config.value[0], 27 | 'MODEL_CLASS': model_class.__name__, 28 | 'TOKENIZER_CLASS': tokenizer_class.__name__, 29 | 'MODEL_NAME_OR_PATH': model_config.value[1], 30 | 'DO_TRAIN': run_config.do_train, 31 | 'DO_EVAL': run_config.do_eval, 32 | 'DO_LOWER_CASE': model_config.value[2], 33 | 'MAX_SEQ_LENGTH': run_config.max_seq_length, 34 | 'DOC_STRIDE': run_config.doc_stride, 35 | 'PER_GPU_BATCH_SIZE': run_config.train_batch_size, 36 | 'GRADIENT_ACCUMULATION_STEPS': run_config.gradient_accumulation_steps, 37 | 'WARMUP_STEPS': run_config.warmup_steps, 38 | 'LEARNING_RATE': run_config.learning_rate, 39 | 'NUM_TRAIN_EPOCHS': run_config.num_train_epochs, 40 | 'PER_GPU_EVAL_BATCH_SIZE': run_config.eval_batch_size, 41 | 'N_BEST_SIZE': run_config.n_best_size, 42 | 'MAX_ANSWER_LENGTH': run_config.max_answer_length, 43 | 'SENTENCE_BOUNDARY_HEURISTIC': run_config.sentence_boundary_heuristic, 44 | 'FULL_SENTENCE_HEURISTIC': run_config.full_sentence_heuristic, 45 | 'SHARED_SENTENCE_HEURISTIC': run_config.shared_sentence_heuristic, 46 | 'OPTIMIZER': str(run_config.optimizer_class), 47 | 'WEIGHT_DECAY': run_config.weight_decay, 48 | 'SCHEDULER_FUNCTION': str(run_config.scheduler_function) 49 | } 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | /BACKUP/
131 | .idea/
132 | /data/*.csv
133 | /output/**
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Span-based Causality Extraction for Financial Documents
2 | This repository contains the supporting code for the FinCausal 2020 submission titled `Span-based Causality Extraction for Financial Documents`. The model extracts cause and effect spans from financial documents, for example:
3 |
4 | |Text|Cause|Effect|
5 | | ------------- |:-----------------------:| ---------------------:|
6 | |Boussard Gavaudan Investment Management LLP bought a new position in shares of GENFIT S A/ADR in the second quarter worth about $199,000. Morgan Stanley increased its stake in shares of GENFIT S A/ADR by 24.4% in the second quarter. Morgan Stanley now owns 10,700 shares of the company's stock worth $211,000 after purchasing an additional 2,100 shares during the period|Morgan Stanley increased its stake in shares of GENFIT S A/ADR by 24.4% in the second quarter|Morgan Stanley now owns 10,700 shares of the company's stock worth $211,000 after purchasing an additional 2,100 shares during the period.|
7 | |Zhao found himself 60 million yuan indebted after losing 9,000 BTC in a single day (February 10, 2014)|losing 9,000 BTC in a single day (February 10, 2014)|Zhao found himself 60 million yuan indebted|
8 |
9 | (sample from the [task description: Data Processing and Metrics for FinCausal Shared Task, 2020, Mariko et al.](https://drive.google.com/file/d/1LUTJVj9ItJMZzKvy1LrCTuBK2SITzr1z/view))
10 |
11 | The system ranked 2nd on the official evaluation leaderboard and reached the following performance in post-evaluation:
12 |
13 | #### Post-evaluation performance
14 |
15 | |Metric|Score|
16 | |:-------------|:-------------:|
17 | |weighted-averaged F1|95.01%|
18 | |Exact matches| 83.34%|
19 | |weighted-averaged Precision| 95.01%|
20 | |weighted-averaged Recall| 95.01%|
21 |
22 | #### Official evaluation performance
23 |
24 | |Metric|Score|
25 | |:-------------|:-------------:|
26 | |weighted-averaged F1|94.66%|
27 | |Exact matches| 73.66%|
28 | |weighted-averaged Precision| 94.66%|
29 | |weighted-averaged Recall| 94.66%|
30 |
31 | The system is based on a RoBERTa span-extraction model (similar to a Question Answering architecture); a full description of the system is available in the accompanying system description paper.
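To make the span-extraction formulation concrete, the sketch below (illustrative only, not code from this repository) shows the simplest possible decoding of cause and effect spans from the four token-level logit tensors returned by the models in `src/models`; the actual decoding in `src/evaluation.py` additionally searches the `n_best_size` candidates and applies the sentence-level heuristics configured in `src/config.py`.

```python
import torch

def greedy_decode_spans(start_cause_logits: torch.Tensor,
                        end_cause_logits: torch.Tensor,
                        start_effect_logits: torch.Tensor,
                        end_effect_logits: torch.Tensor):
    """Pick the highest-scoring (start, end) token index for each span type.

    Each input has shape (sequence_length,) for a single example.
    """
    cause = (int(start_cause_logits.argmax()), int(end_cause_logits.argmax()))
    effect = (int(start_effect_logits.argmax()), int(end_effect_logits.argmax()))
    return cause, effect
```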
If you find this system useful, please cite us:
32 |
33 | ```latex
34 | @inproceedings{
35 |     Becquin-fincausal-2020,
36 |     title = {{GBe at FinCausal 2020, Task 2: Span-based Causality Extraction for Financial Documents}},
37 |     author = {Becquin, Guillaume},
38 |     booktitle = {{The 1st Joint Workshop on Financial Narrative Processing and MultiLing Financial Summarisation (FNP-FNS 2020)}},
39 |     year = {2020},
40 |     address = {Barcelona, Spain}
41 | }
42 | ```
43 |
44 | ## Instructions
45 |
46 | 0. Install the requirements listed in `requirements.txt` (using a virtual environment is advised)
47 | 1. Generate the train / development data split by running `./utils/split_dataset.py`
48 |
49 | ### Training:
50 | 2. run `main.py --train`
51 |
52 | ### Evaluation:
53 | 2. run `main.py --eval`
54 |
55 | ### Generate predictions:
56 | 2. run `main.py --test`
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Package                 Version      Location
2 | ----------------------- ------------ --------------------------
3 | absl-py                 0.9.0
4 | backcall                0.1.0
5 | blis                    0.4.1
6 | boto3                   1.12.39
7 | botocore                1.15.39
8 | cachetools              4.1.0
9 | catalogue               1.0.0
10 | certifi                2020.4.5.1
11 | chardet                3.0.4
12 | click                  7.1.1
13 | colorama               0.4.3
14 | cycler                 0.10.0
15 | cymem                  2.0.3
16 | decorator              4.4.2
17 | docutils               0.15.2
18 | en-core-web-lg         2.2.5
19 | filelock               3.0.12
20 | funcy                  1.14
21 | google-auth            1.13.1
22 | google-auth-oauthlib   0.4.1
23 | grpcio                 1.28.1
24 | idna                   2.9
25 | importlib-metadata     1.6.1
26 | ipykernel              5.3.0
27 | ipython                7.15.0
28 | ipython-genutils       0.2.0
29 | jedi                   0.17.0
30 | jmespath               0.9.5
31 | joblib                 0.14.1
32 | jupyter-client         6.1.3
33 | jupyter-core           4.6.3
34 | kiwisolver             1.2.0
35 | Library                0.0.0
36 | Markdown               3.2.1
37 | matplotlib             3.2.2
38 | murmurhash             1.0.2
39 | nltk                   3.5
40 | numpy                  1.18.2
41 | oauthlib               3.1.0
42 | packaging              20.4
43 | pandas                 1.0.3
44 | parso                  0.7.0
45 | pickleshare            0.7.5
46 | pip                    20.0.2
47 | plac                   1.1.3
48 | preshed                3.0.2
49 | prompt-toolkit         3.0.5
50 | protobuf               3.11.3
51 | pyasn1                 0.4.8
52 | pyasn1-modules         0.2.8
53 | Pygments               2.6.1
54 | pyparsing              2.4.7
55 | pysbd                  0.2.3
56 | python-crfsuite        0.9.7
57 | python-dateutil        2.8.1
58 | pytorch-ranger         0.1.1
59 | pytz                   2019.3
60 | pywin32                227
61 | pyzmq                  19.0.1
62 | regex                  2020.4.4
63 | requests               2.23.0
64 | requests-oauthlib      1.3.0
65 | rsa                    4.0
66 | s3transfer             0.3.3
67 | sacremoses             0.0.38
68 | scikit-learn           0.22.2.post1
69 | scipy                  1.4.1
70 | seaborn                0.10.1
71 | sentencepiece          0.1.85
72 | setuptools             46.1.3
73 | six                    1.14.0
74 | sklearn                0.0
75 | spacy                  2.2.4
76 | srsly                  1.0.2
77 | tensorboard            2.2.0
78 | tensorboard-plugin-wit 1.6.0.post3
79 | thinc                  7.4.0
80 | tokenizers             0.7.0rc3
81 | torch                  1.4.0
82 | torch-optimizer        0.0.1a12
83 | tornado                6.0.4
84 | tqdm                   4.45.0
85 | traitlets              4.3.3
86 | transformers           3.0.1        e:\coding\transformers\src
87 | urllib3                1.25.8
88 | wasabi                 0.6.0
89 | wcwidth                0.2.3
90 | Werkzeug               1.0.1
91 | wheel                  0.34.2
92 | zipp                   3.1.0
93 |
--------------------------------------------------------------------------------
/src/models/roberta.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | # Copyright 2020 Guillaume Becquin.
4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | from torch import nn 19 | from transformers import BertPreTrainedModel, RobertaModel 20 | 21 | 22 | class RoBERTaForCauseEffect(BertPreTrainedModel): 23 | def __init__(self, config): 24 | super().__init__(config) 25 | 26 | self.roberta = RobertaModel(config) 27 | self.cause_outputs = nn.Linear(config.hidden_size, config.num_labels) 28 | self.effect_outputs = nn.Linear(config.hidden_size, config.num_labels) 29 | assert config.num_labels == 2 30 | self.init_weights() 31 | 32 | def forward( 33 | self, 34 | input_ids=None, 35 | attention_mask=None, 36 | token_type_ids=None, 37 | position_ids=None, 38 | head_mask=None, 39 | inputs_embeds=None, 40 | start_cause_positions=None, 41 | end_cause_positions=None, 42 | start_effect_positions=None, 43 | end_effect_positions=None, 44 | ): 45 | bert_output = self.roberta( 46 | input_ids=input_ids, 47 | attention_mask=attention_mask, 48 | token_type_ids=token_type_ids, 49 | position_ids=position_ids, 50 | head_mask=head_mask, 51 | inputs_embeds=inputs_embeds 52 | ) 53 | hidden_states = bert_output[0] # (bs, max_query_len, dim) 54 | cause_logits = self.cause_outputs(hidden_states) # (bs, max_query_len, 2) 55 | effect_logits = self.effect_outputs(hidden_states) # (bs, max_query_len, 2) 56 | start_cause_logits, end_cause_logits = cause_logits.split(1, dim=-1) 57 | start_effect_logits, end_effect_logits = effect_logits.split(1, dim=-1) 58 | 59 | start_cause_logits = start_cause_logits.squeeze(-1) # (bs, max_query_len) 60 | end_cause_logits = end_cause_logits.squeeze(-1) # (bs, max_query_len) 61 | start_effect_logits = start_effect_logits.squeeze(-1) # (bs, max_query_len) 62 | end_effect_logits = end_effect_logits.squeeze(-1) # (bs, max_query_len) 63 | 64 | outputs = (start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits,) \ 65 | + bert_output[2:] 66 | if start_cause_positions is not None \ 67 | and end_cause_positions is not None \ 68 | and start_effect_positions is not None \ 69 | and end_effect_positions is not None: 70 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 71 | ignored_index = start_cause_logits.size(1) 72 | start_cause_positions.clamp_(0, ignored_index) 73 | end_cause_positions.clamp_(0, ignored_index) 74 | start_effect_positions.clamp_(0, ignored_index) 75 | end_effect_positions.clamp_(0, ignored_index) 76 | 77 | loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) 78 | start_cause_loss = loss_fct(start_cause_logits, start_cause_positions) 79 | end_cause_loss = loss_fct(end_cause_logits, end_cause_positions) 80 | start_effect_loss = loss_fct(start_effect_logits, start_effect_positions) 81 | end_effect_loss = loss_fct(end_effect_logits, end_effect_positions) 82 | total_loss = (start_cause_loss + end_cause_loss + start_effect_loss + end_effect_loss) / 4 83 | outputs = (total_loss,) + outputs 84 | 85 | return outputs 86 | 
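A minimal smoke test for the `RoBERTaForCauseEffect` class above (a sketch under the pinned `transformers==3.0.1` API, not part of the original pipeline; the cause/effect heads are randomly initialized until fine-tuned, and `main.py` is the supported entry point):

```python
import torch
from transformers import RobertaTokenizer
from src.models.roberta import RoBERTaForCauseEffect

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RoBERTaForCauseEffect.from_pretrained("roberta-base")  # heads start untrained
model.eval()

encoded = tokenizer("Profits fell because demand collapsed.", return_tensors="pt")
with torch.no_grad():
    # Without span positions, the forward pass returns only the four logit tensors.
    outputs = model(input_ids=encoded["input_ids"],
                    attention_mask=encoded["attention_mask"])
start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits = outputs[:4]
print(start_cause_logits.shape)  # torch.Size([1, sequence_length])
```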
-------------------------------------------------------------------------------- /src/models/albert.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # Copyright 2020 Guillaume Becquin. 4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | 19 | from torch import nn 20 | from transformers import AlbertModel, AlbertPreTrainedModel 21 | 22 | 23 | class AlbertForCauseEffect(AlbertPreTrainedModel): 24 | def __init__(self, config): 25 | super(AlbertForCauseEffect, self).__init__(config) 26 | 27 | self.albert = AlbertModel(config) 28 | self.cause_outputs = nn.Linear(config.hidden_size, config.num_labels) 29 | self.effect_outputs = nn.Linear(config.hidden_size, config.num_labels) 30 | assert config.num_labels == 2 31 | self.init_weights() 32 | 33 | def forward( 34 | self, 35 | input_ids=None, 36 | attention_mask=None, 37 | token_type_ids=None, 38 | position_ids=None, 39 | head_mask=None, 40 | inputs_embeds=None, 41 | start_cause_positions=None, 42 | end_cause_positions=None, 43 | start_effect_positions=None, 44 | end_effect_positions=None, 45 | ): 46 | albert_output = self.albert( 47 | input_ids=input_ids, 48 | attention_mask=attention_mask, 49 | token_type_ids=token_type_ids, 50 | position_ids=position_ids, 51 | head_mask=head_mask, 52 | inputs_embeds=inputs_embeds 53 | ) 54 | hidden_states = albert_output[0] # (bs, max_query_len, dim) 55 | cause_logits = self.cause_outputs(hidden_states) # (bs, max_query_len, 2) 56 | effect_logits = self.effect_outputs(hidden_states) # (bs, max_query_len, 2) 57 | start_cause_logits, end_cause_logits = cause_logits.split(1, dim=-1) 58 | start_effect_logits, end_effect_logits = effect_logits.split(1, dim=-1) 59 | 60 | start_cause_logits = start_cause_logits.squeeze(-1) # (bs, max_query_len) 61 | end_cause_logits = end_cause_logits.squeeze(-1) # (bs, max_query_len) 62 | start_effect_logits = start_effect_logits.squeeze(-1) # (bs, max_query_len) 63 | end_effect_logits = end_effect_logits.squeeze(-1) # (bs, max_query_len) 64 | 65 | outputs = (start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits,) \ 66 | + albert_output[2:] 67 | if start_cause_positions is not None \ 68 | and end_cause_positions is not None \ 69 | and start_effect_positions is not None \ 70 | and end_effect_positions is not None: 71 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 72 | ignored_index = start_cause_logits.size(1) 73 | start_cause_positions.clamp_(0, ignored_index) 74 | end_cause_positions.clamp_(0, ignored_index) 75 | start_effect_positions.clamp_(0, ignored_index) 76 | end_effect_positions.clamp_(0, ignored_index) 77 | 78 | loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) 79 | start_cause_loss = loss_fct(start_cause_logits, start_cause_positions) 80 
| end_cause_loss = loss_fct(end_cause_logits, end_cause_positions) 81 | start_effect_loss = loss_fct(start_effect_logits, start_effect_positions) 82 | end_effect_loss = loss_fct(end_effect_logits, end_effect_positions) 83 | total_loss = (start_cause_loss + end_cause_loss + start_effect_loss + end_effect_loss) / 4 84 | outputs = (total_loss,) + outputs 85 | 86 | return outputs 87 | -------------------------------------------------------------------------------- /src/models/distilbert.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # Copyright 2020 Guillaume Becquin. 4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | from torch import nn 19 | from transformers import DistilBertPreTrainedModel, DistilBertModel 20 | 21 | 22 | class DistilBertForCauseEffect(DistilBertPreTrainedModel): 23 | def __init__(self, config): 24 | super().__init__(config) 25 | 26 | self.distilbert = DistilBertModel(config) 27 | self.cause_outputs = nn.Linear(config.dim, config.num_labels) 28 | self.effect_outputs = nn.Linear(config.dim, config.num_labels) 29 | assert config.num_labels == 2 30 | self.dropout = nn.Dropout(config.qa_dropout) 31 | self.init_weights() 32 | 33 | def forward( 34 | self, 35 | input_ids=None, 36 | attention_mask=None, 37 | head_mask=None, 38 | inputs_embeds=None, 39 | start_cause_positions=None, 40 | end_cause_positions=None, 41 | start_effect_positions=None, 42 | end_effect_positions=None, 43 | ): 44 | distilbert_output = self.distilbert( 45 | input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds 46 | ) 47 | hidden_states = distilbert_output[0] # (bs, max_query_len, dim) 48 | 49 | hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim) 50 | cause_logits = self.cause_outputs(hidden_states) # (bs, max_query_len, 2) 51 | effect_logits = self.effect_outputs(hidden_states) # (bs, max_query_len, 2) 52 | start_cause_logits, end_cause_logits = cause_logits.split(1, dim=-1) 53 | start_effect_logits, end_effect_logits = effect_logits.split(1, dim=-1) 54 | 55 | start_cause_logits = start_cause_logits.squeeze(-1) # (bs, max_query_len) 56 | end_cause_logits = end_cause_logits.squeeze(-1) # (bs, max_query_len) 57 | start_effect_logits = start_effect_logits.squeeze(-1) # (bs, max_query_len) 58 | end_effect_logits = end_effect_logits.squeeze(-1) # (bs, max_query_len) 59 | 60 | outputs = (start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits,) \ 61 | + distilbert_output[1:] 62 | if start_cause_positions is not None \ 63 | and end_cause_positions is not None \ 64 | and start_effect_positions is not None \ 65 | and end_effect_positions is not None: 66 | # sometimes the start/end positions are outside our model inputs, we ignore these 
terms 67 | ignored_index = start_cause_logits.size(1) 68 | start_cause_positions.clamp_(0, ignored_index) 69 | end_cause_positions.clamp_(0, ignored_index) 70 | start_effect_positions.clamp_(0, ignored_index) 71 | end_effect_positions.clamp_(0, ignored_index) 72 | 73 | loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) 74 | start_cause_loss = loss_fct(start_cause_logits, start_cause_positions) 75 | end_cause_loss = loss_fct(end_cause_logits, end_cause_positions) 76 | start_effect_loss = loss_fct(start_effect_logits, start_effect_positions) 77 | end_effect_loss = loss_fct(end_effect_logits, end_effect_positions) 78 | total_loss = (start_cause_loss + end_cause_loss + start_effect_loss + end_effect_loss) / 4 79 | outputs = (total_loss,) + outputs 80 | 81 | return outputs # (loss), start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits(hidden_states), (attentions) 82 | -------------------------------------------------------------------------------- /src/models/bert.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # Copyright 2020 Guillaume Becquin. 4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | from torch import nn 19 | from transformers import BertModel, BertPreTrainedModel 20 | 21 | 22 | class BertForCauseEffect(BertPreTrainedModel): 23 | def __init__(self, config): 24 | super(BertForCauseEffect, self).__init__(config) 25 | 26 | self.bert = BertModel(config) 27 | self.cause_outputs = nn.Linear(config.hidden_size, config.num_labels) 28 | self.effect_outputs = nn.Linear(config.hidden_size, config.num_labels) 29 | assert config.num_labels == 2 30 | self.init_weights() 31 | 32 | def forward( 33 | self, 34 | input_ids=None, 35 | attention_mask=None, 36 | token_type_ids=None, 37 | position_ids=None, 38 | head_mask=None, 39 | inputs_embeds=None, 40 | start_cause_positions=None, 41 | end_cause_positions=None, 42 | start_effect_positions=None, 43 | end_effect_positions=None, 44 | ): 45 | bert_output = self.bert( 46 | input_ids=input_ids, 47 | attention_mask=attention_mask, 48 | token_type_ids=token_type_ids, 49 | position_ids=position_ids, 50 | head_mask=head_mask, 51 | inputs_embeds=inputs_embeds 52 | ) 53 | hidden_states = bert_output[0] # (bs, max_query_len, dim) 54 | cause_logits = self.cause_outputs(hidden_states) # (bs, max_query_len, 2) 55 | effect_logits = self.effect_outputs(hidden_states) # (bs, max_query_len, 2) 56 | start_cause_logits, end_cause_logits = cause_logits.split(1, dim=-1) 57 | start_effect_logits, end_effect_logits = effect_logits.split(1, dim=-1) 58 | 59 | start_cause_logits = start_cause_logits.squeeze(-1) # (bs, max_query_len) 60 | end_cause_logits = end_cause_logits.squeeze(-1) # (bs, max_query_len) 61 | start_effect_logits = start_effect_logits.squeeze(-1) # (bs, max_query_len) 62 | end_effect_logits = end_effect_logits.squeeze(-1) # (bs, max_query_len) 63 | 64 | outputs = (start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits,) \ 65 | + bert_output[2:] 66 | if start_cause_positions is not None \ 67 | and end_cause_positions is not None \ 68 | and start_effect_positions is not None \ 69 | and end_effect_positions is not None: 70 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 71 | ignored_index = start_cause_logits.size(1) 72 | start_cause_positions.clamp_(0, ignored_index) 73 | end_cause_positions.clamp_(0, ignored_index) 74 | start_effect_positions.clamp_(0, ignored_index) 75 | end_effect_positions.clamp_(0, ignored_index) 76 | 77 | loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) 78 | start_cause_loss = loss_fct(start_cause_logits, start_cause_positions) 79 | end_cause_loss = loss_fct(end_cause_logits, end_cause_positions) 80 | start_effect_loss = loss_fct(start_effect_logits, start_effect_positions) 81 | end_effect_loss = loss_fct(end_effect_logits, end_effect_positions) 82 | total_loss = (start_cause_loss + end_cause_loss + start_effect_loss + end_effect_loss) / 4 83 | outputs = (total_loss,) + outputs 84 | 85 | return outputs # (loss), start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits(hidden_states), (attentions) 86 | -------------------------------------------------------------------------------- /src/models/xlnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # Copyright 2020 Guillaume Becquin. 
4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | from torch import nn 19 | from torch.nn import CrossEntropyLoss 20 | from transformers import XLNetPreTrainedModel, XLNetModel 21 | 22 | 23 | class XLNetForCauseEffect(XLNetPreTrainedModel): 24 | def __init__(self, config): 25 | super().__init__(config) 26 | self.num_labels = config.num_labels 27 | 28 | self.transformer = XLNetModel(config) 29 | self.cause_outputs = nn.Linear(config.hidden_size, config.num_labels) 30 | self.effect_outputs = nn.Linear(config.hidden_size, config.num_labels) 31 | self.init_weights() 32 | 33 | def forward( 34 | self, 35 | input_ids=None, 36 | attention_mask=None, 37 | mems=None, 38 | perm_mask=None, 39 | target_mapping=None, 40 | token_type_ids=None, 41 | input_mask=None, 42 | head_mask=None, 43 | inputs_embeds=None, 44 | use_cache=True, 45 | start_cause_positions=None, 46 | end_cause_positions=None, 47 | start_effect_positions=None, 48 | end_effect_positions=None, 49 | ): 50 | 51 | outputs = self.transformer( 52 | input_ids, 53 | attention_mask=attention_mask, 54 | mems=mems, 55 | perm_mask=perm_mask, 56 | target_mapping=target_mapping, 57 | token_type_ids=token_type_ids, 58 | input_mask=input_mask, 59 | head_mask=head_mask, 60 | inputs_embeds=inputs_embeds, 61 | use_cache=use_cache, 62 | ) 63 | 64 | sequence_output = outputs[0] 65 | 66 | cause_logits = self.cause_outputs(sequence_output) 67 | start_cause_logits, end_cause_logits = cause_logits.split(1, dim=-1) 68 | start_cause_logits = start_cause_logits.squeeze(-1) 69 | end_cause_logits = end_cause_logits.squeeze(-1) 70 | 71 | effect_logits = self.effect_outputs(sequence_output) 72 | start_effect_logits, end_effect_logits = effect_logits.split(1, dim=-1) 73 | start_effect_logits = start_effect_logits.squeeze(-1) 74 | end_effect_logits = end_effect_logits.squeeze(-1) 75 | 76 | outputs = (start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits,) + outputs[2:] 77 | if start_cause_positions is not None \ 78 | and end_cause_positions is not None \ 79 | and start_effect_positions is not None \ 80 | and end_effect_positions is not None: 81 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 82 | ignored_index = start_cause_logits.size(1) 83 | start_cause_positions.clamp_(0, ignored_index) 84 | end_cause_positions.clamp_(0, ignored_index) 85 | start_effect_positions.clamp_(0, ignored_index) 86 | end_effect_positions.clamp_(0, ignored_index) 87 | 88 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 89 | start_cause_loss = loss_fct(start_cause_logits, start_cause_positions) 90 | end_cause_loss = loss_fct(end_cause_logits, end_cause_positions) 91 | start_effect_loss = loss_fct(start_effect_logits, start_effect_positions) 92 | end_effect_loss = loss_fct(end_effect_logits, end_effect_positions) 93 | total_loss = (start_cause_loss + end_cause_loss + start_effect_loss + end_effect_loss) / 4 94 | outputs = (total_loss,) 
+ outputs 95 | 96 | return outputs 97 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020, Guillaume Becquin 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from enum import Enum 16 | from typing import Callable 17 | 18 | import torch 19 | from transformers import AdamW, get_cosine_schedule_with_warmup, AlbertTokenizer, XLNetTokenizer, RobertaTokenizer, \ 20 | BertTokenizer, DistilBertTokenizer 21 | 22 | from src.models.albert import AlbertForCauseEffect 23 | from src.models.bert import BertForCauseEffect 24 | from src.models.distilbert import DistilBertForCauseEffect 25 | from src.models.roberta import RoBERTaForCauseEffect 26 | from src.models.xlnet import XLNetForCauseEffect 27 | 28 | 29 | class ModelConfigurations(Enum): 30 | BertBase = ('bert', 'bert-base-cased', False) 31 | BertLarge = ('bert', 'bert-large-cased', False) 32 | BertSquad = ('bert', 'deepset/bert-base-cased-squad2', False) 33 | BertSquad2 = ('bert', 'deepset/bert-large-uncased-whole-word-masking-squad2', True) 34 | DistilBert = ('distilbert', 'distilbert-base-uncased', True) 35 | DistilBertSquad = ('distilbert', 'distilbert-base-uncased-distilled-squad', True) 36 | RoBERTaSquad = ('roberta', 'deepset/roberta-base-squad2', False) 37 | RoBERTaSquadLarge = ('roberta', 'ahotrod/roberta_large_squad2', False) 38 | RoBERTa = ('roberta', 'roberta-base', False) 39 | RoBERTaLarge = ('roberta', 'roberta-large', False) 40 | XLNetBase = ('xlnet', 'xlnet-base-cased', False) 41 | AlbertSquad = ('albert', 'twmkn9/albert-base-v2-squad2', True) 42 | 43 | 44 | model_tokenizer_mapping = { 45 | 'distilbert': (DistilBertForCauseEffect, DistilBertTokenizer), 46 | 'bert': (BertForCauseEffect, BertTokenizer), 47 | 'roberta': (RoBERTaForCauseEffect, RobertaTokenizer), 48 | 'xlnet': (XLNetForCauseEffect, XLNetTokenizer), 49 | 'albert': (AlbertForCauseEffect, AlbertTokenizer), 50 | } 51 | 52 | 53 | class RunConfig: 54 | def __init__(self, 55 | do_train: bool = False, 56 | do_eval: bool = True, 57 | do_test: bool = False, 58 | max_seq_length: int = 384, 59 | doc_stride: int = 128, 60 | train_batch_size: int = 4, 61 | gradient_accumulation_steps: int = 3, 62 | warmup_steps: int = 50, 63 | learning_rate: float = 3e-5, 64 | differential_lr_ratio: float = 1.0, 65 | max_grad_norm: float = 1.0, 66 | adam_epsilon: float = 1e-8, 67 | num_train_epochs: int = 5, 68 | save_model: bool = True, 69 | weight_decay: float = 0.0, 70 | optimizer_class: torch.optim.Optimizer = AdamW, 71 | scheduler_function: Callable = get_cosine_schedule_with_warmup, 72 | evaluate_during_training: bool = True, 73 | eval_batch_size: int = 8, 74 | n_best_size: int = 5, 75 | max_answer_length: int = 300, 76 | sentence_boundary_heuristic: bool = True, 77 | full_sentence_heuristic: bool = True, 78 | shared_sentence_heuristic: bool = False, 79 | top_n_sentences: bool = True): 80 | self.do_train = 
do_train 81 | self.do_eval = do_eval 82 | self.do_test = do_test 83 | self.max_seq_length = max_seq_length 84 | self.doc_stride = doc_stride 85 | self.train_batch_size = train_batch_size 86 | self.gradient_accumulation_steps = gradient_accumulation_steps 87 | self.warmup_steps = warmup_steps 88 | self.learning_rate = learning_rate 89 | self.differential_lr_ratio = differential_lr_ratio 90 | self.max_grad_norm = max_grad_norm 91 | self.adam_epsilon = adam_epsilon 92 | self.num_train_epochs = num_train_epochs 93 | self.save_model = save_model 94 | self.weight_decay = weight_decay 95 | self.optimizer_class = optimizer_class 96 | self.scheduler_function = scheduler_function 97 | self.evaluate_during_training = evaluate_during_training 98 | self.eval_batch_size = eval_batch_size 99 | self.n_best_size = n_best_size 100 | self.max_answer_length = max_answer_length 101 | self.sentence_boundary_heuristic = sentence_boundary_heuristic 102 | self.full_sentence_heuristic = full_sentence_heuristic 103 | self.shared_sentence_heuristic = shared_sentence_heuristic 104 | self.top_n_sentences = top_n_sentences 105 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # Copyright 2020 Guillaume Becquin. All rights reserved. 4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
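# Example invocation (illustrative sketch, assuming the train/dev split produced by
# utils/split_dataset.py is in ./data):
#   FINCAUSAL_DATA_PATH=./data FINCAUSAL_OUTPUT_PATH=./output python main.py --train --eval
# Both environment variables are optional; they default to the ./data and ./output
# directories next to this script.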
17 |
18 |
19 | import json
20 | import logging
21 | import os
22 | import sys
23 | from pathlib import Path
24 | import torch
25 | from src.config import RunConfig, ModelConfigurations, model_tokenizer_mapping
26 | from src.evaluation import evaluate, predict
27 | from src.logging import initialize_log_dict
28 | from src.preprocessing import load_examples
29 | from src.training import train
30 | import argparse
31 |
32 | logging.basicConfig(level=logging.INFO)
33 | logger = logging.getLogger(__name__)
34 |
35 | if __name__ == '__main__':
36 |
37 |     parser = argparse.ArgumentParser()
38 |     parser.add_argument(
39 |         "--train",
40 |         default=False,
41 |         required=False,
42 |         action="store_true",
43 |         help="Flag to specify if the model should be trained",
44 |     )
45 |
46 |     parser.add_argument(
47 |         "--eval",
48 |         default=False,
49 |         required=False,
50 |         action="store_true",
51 |         help="Flag to specify if the model should be evaluated",
52 |     )
53 |
54 |     parser.add_argument(
55 |         "--test",
56 |         default=False,
57 |         required=False,
58 |         action="store_true",
59 |         help="Flag to specify if the model should generate predictions on the test file",
60 |     )
61 |     args = parser.parse_args()
62 |     assert args.train or args.eval or args.test, \
63 |         "At least one task needs to be selected by passing --train, --eval or --test"
64 |
65 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66 |
67 |     model_config = ModelConfigurations.RoBERTaSquadLarge
68 |     run_config = RunConfig()
69 |     run_config.do_train = args.train
70 |     run_config.do_eval = args.eval
71 |     run_config.do_test = args.test
72 |
73 |     RUN_NAME = 'model_run'
74 |
75 |     (MODEL_TYPE, MODEL_NAME_OR_PATH, DO_LOWER_CASE) = model_config.value
76 |
77 |     fincausal_data_path = Path(os.environ.get('FINCAUSAL_DATA_PATH',
78 |                                               os.path.dirname(os.path.realpath(sys.argv[0])) + '/data'))
79 |     fincausal_output_path = Path(os.environ.get('FINCAUSAL_OUTPUT_PATH',
80 |                                                 os.path.dirname(os.path.realpath(sys.argv[0])) + '/output'))
81 |
82 |     TRAIN_FILE = fincausal_data_path / "fnp2020-train.csv"
83 |     EVAL_FILE = fincausal_data_path / "fnp2020-eval.csv"
84 |     TEST_FILE = fincausal_data_path / "task2.csv"
85 |
86 |     if RUN_NAME:
87 |         OUTPUT_DIR = fincausal_output_path / (MODEL_NAME_OR_PATH + '_' + RUN_NAME)
88 |     else:
89 |         OUTPUT_DIR = fincausal_output_path / MODEL_NAME_OR_PATH
90 |
91 |     model_class, tokenizer_class = model_tokenizer_mapping[MODEL_TYPE]
92 |     log_file = initialize_log_dict(model_config=model_config,
93 |                                    run_config=run_config,
94 |                                    model_tokenizer_mapping=model_tokenizer_mapping)
95 |
96 |     # Training
97 |     if run_config.do_train:
98 |
99 |         tokenizer = tokenizer_class.from_pretrained(MODEL_NAME_OR_PATH,
100 |                                                     do_lower_case=DO_LOWER_CASE,
101 |                                                     cache_dir=OUTPUT_DIR)
102 |         model = model_class.from_pretrained(MODEL_NAME_OR_PATH).to(device)
103 |
104 |         train_dataset = load_examples(file_path=TRAIN_FILE,
105 |                                       tokenizer=tokenizer,
106 |                                       output_examples=False,
107 |                                       run_config=run_config)
108 |
109 |         train(train_dataset=train_dataset,
110 |               model=model,
111 |               tokenizer=tokenizer,
112 |               model_type=MODEL_TYPE,
113 |               output_dir=OUTPUT_DIR,
114 |               predict_file=EVAL_FILE,
115 |               device=device,
116 |               log_file=log_file,
117 |               run_config=run_config
118 |               )
119 |         if not OUTPUT_DIR.is_dir():
120 |             OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
121 |         if run_config.save_model:
122 |             model_to_save = model.module if hasattr(model, "module") else model
123 |             model_to_save.save_pretrained(OUTPUT_DIR)
124 |             tokenizer.save_pretrained(OUTPUT_DIR)
125 |             logger.info("Saving
final model to %s", OUTPUT_DIR) 126 | logger.info("Saving log file to %s", OUTPUT_DIR) 127 | with open(os.path.join(OUTPUT_DIR, "logs.json"), 'w') as f: 128 | json.dump(log_file, f, indent=4) 129 | 130 | if run_config.do_eval: 131 | tokenizer = tokenizer_class.from_pretrained(str(OUTPUT_DIR), do_lower_case=DO_LOWER_CASE) 132 | model = model_class.from_pretrained(str(OUTPUT_DIR)).to(device) 133 | 134 | result = evaluate(model=model, 135 | tokenizer=tokenizer, 136 | device=device, 137 | file_path=EVAL_FILE, 138 | model_type=MODEL_TYPE, 139 | output_dir=OUTPUT_DIR, 140 | run_config=run_config 141 | ) 142 | 143 | print("done") 144 | 145 | if run_config.do_test: 146 | tokenizer = tokenizer_class.from_pretrained(str(OUTPUT_DIR), do_lower_case=DO_LOWER_CASE) 147 | model = model_class.from_pretrained(str(OUTPUT_DIR)).to(device) 148 | 149 | result = predict(model=model, 150 | tokenizer=tokenizer, 151 | device=device, 152 | file_path=TEST_FILE, 153 | model_type=MODEL_TYPE, 154 | output_dir=OUTPUT_DIR, 155 | run_config=run_config 156 | ) 157 | 158 | print("done") 159 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # Copyright 2020 Guillaume Becquin. 4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import unicodedata 19 | from typing import Optional, List 20 | 21 | import pandas as pd 22 | 23 | 24 | class FinCausalExample: 25 | 26 | def __init__( 27 | self, 28 | example_id: str, 29 | context_text: str, 30 | cause_text: str, 31 | effect_text: str, 32 | offset_sentence_2: int, 33 | offset_sentence_3: int, 34 | cause_start_position_character: Optional[int], 35 | cause_end_position_character: Optional[int], 36 | effect_start_position_character: Optional[int], 37 | effect_end_position_character: Optional[int] 38 | ): 39 | 40 | self.example_id = example_id 41 | self.context_text = context_text 42 | self.cause_text = cause_text 43 | self.effect_text = effect_text 44 | 45 | self.start_cause_position, self.end_cause_position = 0, 0 46 | self.start_effect_position, self.end_effect_position = 0, 0 47 | 48 | doc_tokens: List[str] = [] 49 | char_to_word_offset: List[int] = [] 50 | word_to_char_mapping: List[int] = [] 51 | prev_is_whitespace = True 52 | 53 | # Split on whitespace so that different tokens may be attributed to their original position. 
54 | for char_index, c in enumerate(self.context_text): 55 | if _is_whitespace(c): 56 | prev_is_whitespace = True 57 | else: 58 | if prev_is_whitespace or _is_punctuation(c): 59 | doc_tokens.append(c) 60 | word_to_char_mapping.append(char_index) 61 | else: 62 | doc_tokens[-1] += c 63 | if _is_punctuation(c): 64 | prev_is_whitespace = True 65 | else: 66 | prev_is_whitespace = False 67 | char_to_word_offset.append(len(doc_tokens) - 1) 68 | 69 | self.doc_tokens = doc_tokens 70 | self.char_to_word_offset = char_to_word_offset 71 | 72 | # Start and end positions only has a value during evaluation. 73 | if cause_start_position_character is not None: 74 | assert (cause_start_position_character + len(cause_text) == cause_end_position_character) 75 | self.start_cause_position = char_to_word_offset[cause_start_position_character] 76 | self.end_cause_position = char_to_word_offset[ 77 | min(cause_start_position_character + len(cause_text) - 1, len(char_to_word_offset) - 1) 78 | ] 79 | if effect_start_position_character is not None: 80 | self.start_effect_position = char_to_word_offset[effect_start_position_character] 81 | assert (effect_start_position_character + len(effect_text) == effect_end_position_character) 82 | self.end_effect_position = char_to_word_offset[ 83 | min(effect_start_position_character + len(effect_text) - 1, len(char_to_word_offset) - 1) 84 | ] 85 | 86 | if pd.notna(offset_sentence_2): 87 | self.offset_sentence_2 = char_to_word_offset[int(offset_sentence_2)] 88 | else: 89 | self.offset_sentence_2 = 0 90 | if pd.notna(offset_sentence_3): 91 | self.offset_sentence_3 = char_to_word_offset[int(offset_sentence_3)] 92 | else: 93 | self.offset_sentence_3 = 0 94 | 95 | self.word_to_char_mapping = word_to_char_mapping 96 | 97 | 98 | class FinCausalFeatures: 99 | 100 | def __init__( 101 | self, 102 | input_ids, 103 | attention_mask, 104 | token_type_ids, 105 | cls_index, 106 | p_mask, 107 | example_orig_index, 108 | example_index, 109 | unique_id, 110 | paragraph_len, 111 | token_is_max_context, 112 | tokens, 113 | token_to_orig_map, 114 | cause_start_position, 115 | cause_end_position, 116 | effect_start_position, 117 | effect_end_position, 118 | sentence_2_offset, 119 | sentence_3_offset, 120 | is_impossible, 121 | ): 122 | self.input_ids = input_ids 123 | self.attention_mask = attention_mask 124 | self.token_type_ids = token_type_ids 125 | self.cls_index = cls_index 126 | self.p_mask = p_mask 127 | 128 | self.example_orig_index = example_orig_index 129 | self.example_index = example_index 130 | self.unique_id = unique_id 131 | self.paragraph_len = paragraph_len 132 | self.token_is_max_context = token_is_max_context 133 | self.tokens = tokens 134 | self.token_to_orig_map = token_to_orig_map 135 | 136 | self.cause_start_position = cause_start_position 137 | self.cause_end_position = cause_end_position 138 | self.effect_start_position = effect_start_position 139 | self.effect_end_position = effect_end_position 140 | 141 | self.sentence_2_offset = sentence_2_offset 142 | self.sentence_3_offset = sentence_3_offset 143 | 144 | self.is_impossible = is_impossible 145 | 146 | 147 | class FinCausalResult: 148 | def __init__(self, unique_id, start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits, 149 | start_top_index=None, end_top_index=None, cls_logits=None): 150 | self.start_cause_logits = start_cause_logits 151 | self.end_cause_logits = end_cause_logits 152 | self.start_effect_logits = start_effect_logits 153 | self.end_effect_logits = end_effect_logits 154 | 
self.unique_id = unique_id 155 | 156 | if start_top_index: 157 | self.start_top_index = start_top_index 158 | self.end_top_index = end_top_index 159 | self.cls_logits = cls_logits 160 | 161 | 162 | def _is_whitespace(char: str) -> bool: 163 | if char == " " or char == "\t" or char == "\r" or char == "\n" or char == '\xa0' or ord(char) == 0x202F: 164 | return True 165 | return False 166 | 167 | 168 | def _is_punctuation(char: str) -> bool: 169 | cp = ord(char) 170 | if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126): 171 | return True 172 | cat = unicodedata.category(char) 173 | if cat.startswith("P"): 174 | return True 175 | return False 176 | -------------------------------------------------------------------------------- /src/training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # Copyright 2020 Guillaume Becquin. 4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import logging 19 | from pathlib import Path 20 | from typing import Dict, Union 21 | 22 | import torch 23 | from torch.nn import Module 24 | from torch.utils.data import RandomSampler, DataLoader, TensorDataset 25 | from tqdm import trange, tqdm 26 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 27 | 28 | from .config import RunConfig 29 | from .evaluation import evaluate 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | def train(train_dataset: TensorDataset, 35 | model: Module, 36 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 37 | model_type: str, 38 | output_dir: Path, 39 | predict_file: Path, 40 | log_file: Dict, 41 | device: torch.device, 42 | run_config: RunConfig): 43 | train_sampler = RandomSampler(train_dataset) 44 | train_dataloader = DataLoader(train_dataset, 45 | sampler=train_sampler, 46 | batch_size=run_config.train_batch_size) 47 | 48 | t_total = len(train_dataloader) // run_config.gradient_accumulation_steps * run_config.num_train_epochs 49 | 50 | # Define Optimizer and learning rates / decay 51 | no_decay = ["bias", "LayerNorm.weight"] 52 | no_scaled_lr = ["cause_outputs", "effect_outputs"] 53 | if run_config.differential_lr_ratio == 0: 54 | differential_lr_ratio = 1.0 55 | else: 56 | differential_lr_ratio = run_config.differential_lr_ratio 57 | assert differential_lr_ratio <= 1, "ratio for language model layers should be <= 1" 58 | optimizer_grouped_parameters = [ 59 | { 60 | "params": [p for n, p in model.named_parameters() if (not any(nd in n for nd in no_decay) 61 | and not any(nlr in n for nlr in no_scaled_lr))], 62 | 'lr': run_config.learning_rate * differential_lr_ratio, 63 | "weight_decay": run_config.weight_decay, 64 | }, 65 | { 66 | "params": [p for n, p in model.named_parameters() if (not any(nd in n for nd in no_decay) 67 | and 
any(nlr in n for nlr in no_scaled_lr))], 68 | 'lr': run_config.learning_rate, 69 | "weight_decay": run_config.weight_decay, 70 | }, 71 | { 72 | "params": [p for n, p in model.named_parameters() if (any(nd in n for nd in no_decay) 73 | and not any(nlr in n for nlr in no_scaled_lr))], 74 | 'lr': run_config.learning_rate * differential_lr_ratio, 75 | "weight_decay": 0.0 76 | }, 77 | { 78 | "params": [p for n, p in model.named_parameters() if (any(nd in n for nd in no_decay) 79 | and any(nlr in n for nlr in no_scaled_lr))], 80 | 'lr': run_config.learning_rate, 81 | "weight_decay": 0.0 82 | }, 83 | ] 84 | optimizer = run_config.optimizer_class(optimizer_grouped_parameters, 85 | lr=run_config.learning_rate, 86 | eps=run_config.adam_epsilon) 87 | 88 | # Define Scheduler 89 | try: 90 | scheduler = run_config.scheduler_function(optimizer, 91 | num_warmup_steps=run_config.warmup_steps, 92 | num_training_steps=t_total) 93 | except ValueError: 94 | scheduler = run_config.scheduler_function(optimizer, 95 | num_warmup_steps=run_config.warmup_steps) 96 | 97 | # Start training 98 | logger.info("***** Running training *****") 99 | logger.info(" Num examples = %d", len(train_dataset)) 100 | logger.info(" Num Epochs = %d", run_config.num_train_epochs) 101 | logger.info( 102 | " Total train batch size (w. parallel, distributed & accumulation) = %d", 103 | run_config.train_batch_size * run_config.gradient_accumulation_steps 104 | ) 105 | logger.info(" Gradient Accumulation steps = %d", run_config.gradient_accumulation_steps) 106 | logger.info(" Total optimization steps = %d", t_total) 107 | 108 | global_step = 1 109 | epochs_trained = 0 110 | 111 | tr_loss, logging_loss = 0.0, 0.0 112 | model.zero_grad() 113 | train_iterator = trange(epochs_trained, int(run_config.num_train_epochs), desc="Epoch") 114 | 115 | for _ in train_iterator: 116 | epoch_iterator = tqdm(train_dataloader, desc=f"Iteration Loss: {tr_loss / global_step}", position=0, leave=True) 117 | for step, batch in enumerate(epoch_iterator): 118 | epoch_iterator.set_description(f"Iteration Loss: {tr_loss / global_step}") 119 | 120 | model.train() 121 | batch = tuple(t.to(device) for t in batch) 122 | 123 | inputs = { 124 | "input_ids": batch[0], 125 | "attention_mask": batch[1], 126 | "token_type_ids": batch[2], 127 | "start_cause_positions": batch[3], 128 | "end_cause_positions": batch[4], 129 | "start_effect_positions": batch[5], 130 | "end_effect_positions": batch[6], 131 | } 132 | 133 | if model_type in ["xlm", "roberta", "distilbert", "camembert"]: 134 | del inputs["token_type_ids"] 135 | 136 | outputs = model(**inputs) 137 | loss = outputs[0] 138 | 139 | if run_config.gradient_accumulation_steps > 1: 140 | loss = loss / run_config.gradient_accumulation_steps 141 | 142 | loss.backward() 143 | 144 | tr_loss += loss.item() 145 | if (step + 1) % run_config.gradient_accumulation_steps == 0: 146 | torch.nn.utils.clip_grad_norm_(model.parameters(), run_config.max_grad_norm) 147 | optimizer.step() 148 | scheduler.step() 149 | model.zero_grad() 150 | global_step += 1 151 | 152 | # Evaluate model and log metrics 153 | if run_config.evaluate_during_training: 154 | metrics = evaluate(model=model, 155 | tokenizer=tokenizer, 156 | device=device, 157 | file_path=predict_file, 158 | model_type=model_type, 159 | output_dir=output_dir, 160 | run_config=run_config) 161 | log_file[f'step_{global_step}'] = metrics 162 | 163 | _output_dir = output_dir / "checkpoint-{}".format(global_step) 164 | if not _output_dir.is_dir(): 165 | _output_dir.mkdir(parents=True, 
exist_ok=True) 166 | 167 | model.save_pretrained(_output_dir) 168 | tokenizer.save_pretrained(_output_dir) 169 | logger.info("Best F1 score: saving model checkpoint to %s", _output_dir) 170 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------
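The next file, src/preprocessing.py, converts the raw CSV rows into model features. One step worth seeing in isolation is its sentence segmentation: a blank spaCy pipeline is given pysbd as the sentence boundary detector, and the end_char offsets of the first two sentences are stored on each example (offset_sentence_2, offset_sentence_3) for the sentence-boundary heuristics applied at prediction time. A minimal sketch of that call pattern, assuming the spaCy 2.x and pysbd versions pinned by the repository:

import spacy
from pysbd.utils import PySBDFactory

nlp = spacy.blank('en')          # tokenizer-only pipeline
nlp.add_pipe(PySBDFactory(nlp))  # pysbd supplies the sentence boundaries

doc = nlp("Revenue fell 12%. The group therefore cut costs. Shares rose.")
print([(sent.text, sent.end_char) for sent in doc.sents])
# The end_char of the first and second sentences are what _create_examples
# stores as offset_sentence_2 and offset_sentence_3.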
/src/preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # Copyright 2020, Guillaume Becquin 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import logging 18 | import multiprocessing 19 | from collections import UserDict 20 | from functools import partial 21 | from pathlib import Path 22 | from typing import Union, List, Tuple 23 | 24 | import spacy 25 | from pysbd.utils import PySBDFactory 26 | 27 | import numpy as np 28 | import pandas as pd 29 | import torch 30 | from torch.utils.data import TensorDataset 31 | from tqdm import tqdm 32 | from transformers import PreTrainedTokenizer, BatchEncoding, PreTrainedTokenizerFast 33 | from transformers.tokenization_bert import whitespace_tokenize 34 | 35 | from .config import RunConfig 36 | from .data import FinCausalExample, FinCausalFeatures, _is_punctuation 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | def load_examples(file_path: Path, 42 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 43 | run_config: RunConfig, 44 | output_examples: bool = True, 45 | evaluate: bool = False) -> \ 46 | Union[Tuple[TensorDataset, List[FinCausalExample], List[FinCausalFeatures]], 47 | TensorDataset]: 48 | processor = FinCausalProcessor() 49 | examples = processor.get_examples(file_path) 50 | 51 | features, dataset = fincausal_convert_examples_to_features( 52 | examples=examples, 53 | tokenizer=tokenizer, 54 | max_seq_length=run_config.max_seq_length, 55 | doc_stride=run_config.doc_stride, 56 | is_training=not evaluate, 57 | return_dataset="pt", 58 | threads=1, 59 | ) 60 | 61 | if output_examples: 62 | return dataset, examples, features 63 | return dataset 64 | 65 | 66 | class FinCausalProcessor: 67 | 68 | def get_examples(self, file_path: Path) -> List[FinCausalExample]: 69 | input_data = pd.read_csv(file_path, index_col=0, delimiter=';', header=0, skipinitialspace=True) 70 | input_data.columns = [col_name.strip() for col_name in input_data.columns] 71 | return self._create_examples(input_data) 72 | 73 | @staticmethod 74 | def _create_examples(input_data: pd.DataFrame) -> List[FinCausalExample]: 75 | 76 | nlp = spacy.blank('en') 77 | nlp.add_pipe(PySBDFactory(nlp)) 78 | 79 | examples = [] 80 | for entry in tqdm(input_data.itertuples()): 81 | context_text = entry.Text 82 | example_id = entry.Index 83 | if len(entry) > 2: 84 | cause_text = entry.Cause 85 | effect_text = entry.Effect 86 | cause_start_position_character = entry.Cause_Start 87 | cause_end_position_character = entry.Cause_End 88 | effect_start_position_character = entry.Effect_Start 89 | effect_end_position_character = entry.Effect_End 90 | else: 91 | cause_text = "" 92 | effect_text = "" 93 | cause_start_position_character = 0 94 | cause_end_position_character = 0 95 | effect_start_position_character = 0 96 | effect_end_position_character = 0 97 | 98 | doc = nlp(entry.Text) 99 | sentences = list(doc.sents) 100 | offset_sentence_2 = np.nan 101 | offset_sentence_3 = np.nan 102 | if len(sentences) > 1: 103 | offset_sentence_2 = sentences[0].end_char 104 | if len(sentences) > 2: 105 | offset_sentence_3 = sentences[1].end_char 106 | 107 | example = FinCausalExample( 108 | example_id, 109 | context_text, 110 | cause_text, 111 | effect_text, 112 |
offset_sentence_2, 113 | offset_sentence_3, 114 | cause_start_position_character, 115 | cause_end_position_character, 116 | effect_start_position_character, 117 | effect_end_position_character 118 | ) 119 | examples.append(example) 120 | return examples 121 | 122 | 123 | def fincausal_convert_example_to_features_init( 124 | tokenizer_for_convert: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): 125 | global tokenizer 126 | tokenizer = tokenizer_for_convert 127 | 128 | 129 | def _improve_answer_span(doc_tokens: List[str], 130 | input_start: int, 131 | input_end: int, 132 | tokenizer, 133 | orig_answer_text: str) -> Tuple[int, int]: 134 | """Returns tokenized answer spans that better match the annotated answer.""" 135 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 136 | 137 | for new_start in range(input_start, input_end + 1): 138 | for new_end in range(input_end, new_start - 1, -1): 139 | text_span = " ".join(doc_tokens[new_start: (new_end + 1)]) 140 | if text_span == tok_answer_text: 141 | return new_start, new_end 142 | 143 | return input_start, input_end 144 | 145 | 146 | def _check_is_max_context(doc_spans: Union[List[dict], List[UserDict]], 147 | cur_span_index: int, 148 | position: int) -> bool: 149 | """Check if this is the 'max context' doc span for the token.""" 150 | best_score = None 151 | best_span_index = None 152 | for (span_index, doc_span) in enumerate(doc_spans): 153 | end = doc_span["start"] + doc_span["length"] - 1 154 | if position < doc_span["start"]: 155 | continue 156 | if position > end: 157 | continue 158 | num_left_context = position - doc_span["start"] 159 | num_right_context = end - position 160 | score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"] 161 | if best_score is None or score > best_score: 162 | best_score = score 163 | best_span_index = span_index 164 | 165 | return cur_span_index == best_span_index 166 | 167 | 168 | def fincausal_convert_example_to_features(example: FinCausalExample, 169 | max_seq_length: int, 170 | doc_stride: int, 171 | is_training: bool) -> List[FinCausalFeatures]: 172 | features = [] 173 | if is_training: 174 | # Get start and end position 175 | start_cause_position = example.start_cause_position 176 | end_cause_position = example.end_cause_position 177 | start_effect_position = example.start_effect_position 178 | end_effect_position = example.end_effect_position 179 | 180 | # If the cause cannot be found in the text, then skip this example. 181 | actual_cause_text = " ".join(example.doc_tokens[start_cause_position: (end_cause_position + 1)]) 182 | cleaned_cause_text = " ".join(whitespace_tokenize(_run_split_on_punc(example.cause_text))) 183 | if actual_cause_text.find(cleaned_cause_text) == -1: 184 | logger.warning("Could not find cause: '%s' vs. '%s'", actual_cause_text, cleaned_cause_text) 185 | return [] 186 | 187 | # If the effect cannot be found in the text, then skip this example. 188 | actual_effect_text = " ".join(example.doc_tokens[start_effect_position: (end_effect_position + 1)]) 189 | cleaned_effect_text = " ".join(whitespace_tokenize(_run_split_on_punc(example.effect_text))) 190 | if actual_effect_text.find(cleaned_effect_text) == -1: 191 | logger.warning("Could not find effect: '%s' vs. 
'%s'", actual_effect_text, cleaned_effect_text) 192 | return [] 193 | 194 | tok_to_orig_index = [] 195 | orig_to_tok_index = [] 196 | all_doc_tokens = [] 197 | for (i, token) in enumerate(example.doc_tokens): 198 | orig_to_tok_index.append(len(all_doc_tokens)) 199 | sub_tokens = tokenizer.tokenize(token) 200 | for sub_token in sub_tokens: 201 | tok_to_orig_index.append(i) 202 | all_doc_tokens.append(sub_token) 203 | 204 | if is_training: 205 | tok_cause_start_position = orig_to_tok_index[example.start_cause_position] 206 | if example.end_cause_position < len(example.doc_tokens) - 1: 207 | tok_cause_end_position = orig_to_tok_index[example.end_cause_position + 1] - 1 208 | else: 209 | tok_cause_end_position = len(all_doc_tokens) - 1 210 | 211 | (tok_cause_start_position, tok_cause_end_position) = _improve_answer_span( 212 | all_doc_tokens, tok_cause_start_position, tok_cause_end_position, tokenizer, example.cause_text 213 | ) 214 | 215 | tok_effect_start_position = orig_to_tok_index[example.start_effect_position] 216 | if example.end_effect_position < len(example.doc_tokens) - 1: 217 | tok_effect_end_position = orig_to_tok_index[example.end_effect_position + 1] - 1 218 | else: 219 | tok_effect_end_position = len(all_doc_tokens) - 1 220 | 221 | (tok_effect_start_position, tok_effect_end_position) = _improve_answer_span( 222 | all_doc_tokens, tok_effect_start_position, tok_effect_end_position, tokenizer, example.effect_text 223 | ) 224 | if example.offset_sentence_2 > 0: 225 | tok_sentence_2_offset = orig_to_tok_index[example.offset_sentence_2 + 1] - 1 226 | else: 227 | tok_sentence_2_offset = None 228 | if example.offset_sentence_3 > 0: 229 | tok_sentence_3_offset = orig_to_tok_index[example.offset_sentence_3 + 1] - 1 230 | else: 231 | tok_sentence_3_offset = None 232 | 233 | spans: List[BatchEncoding] = [] 234 | 235 | sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence 236 | 237 | span_doc_tokens = all_doc_tokens 238 | while len(spans) * doc_stride < len(all_doc_tokens): 239 | 240 | encoded_dict: BatchEncoding = tokenizer.encode_plus(span_doc_tokens, 241 | max_length=max_seq_length, 242 | return_overflowing_tokens=True, 243 | pad_to_max_length=True, 244 | stride=max_seq_length - doc_stride - sequence_added_tokens - 1, 245 | truncation_strategy="only_first", 246 | truncation=True, 247 | return_token_type_ids=True, 248 | ) 249 | 250 | paragraph_len = min( 251 | len(all_doc_tokens) - len(spans) * doc_stride, 252 | max_seq_length - sequence_added_tokens, 253 | ) 254 | if tokenizer.pad_token_id in encoded_dict["input_ids"]: 255 | if tokenizer.padding_side == "right": 256 | non_padded_ids = encoded_dict.data["input_ids"][ 257 | : encoded_dict.data["input_ids"].index(tokenizer.pad_token_id)] 258 | else: 259 | last_padding_id_position = ( 260 | len(encoded_dict.data["input_ids"]) 261 | - 1 262 | - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id) 263 | ) 264 | non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1:] 265 | else: 266 | non_padded_ids = encoded_dict["input_ids"] 267 | 268 | tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) 269 | 270 | token_to_orig_map = {} 271 | for i in range(paragraph_len): 272 | index = sequence_added_tokens + i 273 | token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i] 274 | 275 | encoded_dict["paragraph_len"] = paragraph_len 276 | encoded_dict["tokens"] = tokens 277 | encoded_dict["token_to_orig_map"] = token_to_orig_map 278 | encoded_dict["token_is_max_context"] = {} 279 | 
encoded_dict["start"] = len(spans) * doc_stride 280 | encoded_dict["length"] = paragraph_len 281 | 282 | spans.append(encoded_dict) 283 | 284 | if len(encoded_dict.get("overflowing_tokens", [])) == 0: 285 | break 286 | span_doc_tokens = encoded_dict["overflowing_tokens"] 287 | 288 | for doc_span_index in range(len(spans)): 289 | for j in range(spans[doc_span_index].data["paragraph_len"]): 290 | is_max_context = _check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j) 291 | spans[doc_span_index].data["token_is_max_context"][j] = is_max_context 292 | 293 | for span in spans: 294 | # Identify the position of the CLS token 295 | cls_index = span.data["input_ids"].index(tokenizer.cls_token_id) 296 | 297 | p_mask = np.ones(len(span.data["token_type_ids"])) 298 | p_mask[np.where(np.array(span.data["input_ids"]) == tokenizer.sep_token_id)[0]] = 1 299 | # Set the CLS index to '0' 300 | p_mask[cls_index] = 0 301 | 302 | span_is_impossible = False 303 | cause_start_position = 0 304 | cause_end_position = 0 305 | effect_start_position = 0 306 | effect_end_position = 0 307 | doc_start = span.data["start"] 308 | doc_end = span.data["start"] + span.data["length"] - 1 309 | out_of_span = False 310 | if tokenizer.padding_side == "left": 311 | doc_offset = 0 312 | else: 313 | doc_offset = sequence_added_tokens 314 | if tok_sentence_2_offset is not None: 315 | sentence_2_offset = tok_sentence_2_offset - doc_start + doc_offset 316 | else: 317 | sentence_2_offset = None 318 | if tok_sentence_3_offset is not None: 319 | sentence_3_offset = tok_sentence_3_offset - doc_start + doc_offset 320 | else: 321 | sentence_3_offset = None 322 | if is_training: 323 | # For training, if our document chunk does not contain an annotation 324 | # we throw it out, since there is nothing to predict. 
325 | if not (tok_cause_start_position >= doc_start 326 | and tok_cause_end_position <= doc_end 327 | and tok_effect_start_position >= doc_start 328 | and tok_effect_end_position <= doc_end): 329 | out_of_span = True 330 | 331 | if out_of_span: 332 | cause_start_position = cls_index 333 | cause_end_position = cls_index 334 | effect_start_position = cls_index 335 | effect_end_position = cls_index 336 | span_is_impossible = True 337 | else: 338 | cause_start_position = tok_cause_start_position - doc_start + doc_offset 339 | cause_end_position = tok_cause_end_position - doc_start + doc_offset 340 | effect_start_position = tok_effect_start_position - doc_start + doc_offset 341 | effect_end_position = tok_effect_end_position - doc_start + doc_offset 342 | 343 | features.append( 344 | FinCausalFeatures( 345 | span["input_ids"], 346 | span["attention_mask"], 347 | span["token_type_ids"], 348 | cls_index, 349 | p_mask.tolist(), 350 | example_orig_index=example.example_id, 351 | example_index=0, 352 | unique_id=0, 353 | paragraph_len=span["paragraph_len"], 354 | token_is_max_context=span["token_is_max_context"], 355 | tokens=span["tokens"], 356 | token_to_orig_map=span["token_to_orig_map"], 357 | cause_start_position=cause_start_position, 358 | cause_end_position=cause_end_position, 359 | effect_start_position=effect_start_position, 360 | effect_end_position=effect_end_position, 361 | sentence_2_offset=sentence_2_offset, 362 | sentence_3_offset=sentence_3_offset, 363 | is_impossible=span_is_impossible, 364 | ) 365 | ) 366 | return features 367 | 368 | 369 | def fincausal_convert_examples_to_features(examples: List[FinCausalExample], 370 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 371 | max_seq_length: int, 372 | doc_stride: int, 373 | is_training: bool, return_dataset: Union[bool, str] = False, 374 | threads: int = 1) -> Union[List[FinCausalFeatures], 375 | Tuple[List[FinCausalFeatures], TensorDataset]]: 376 | with multiprocessing.Pool(threads, 377 | initializer=fincausal_convert_example_to_features_init, 378 | initargs=(tokenizer,)) as p: 379 | annotate_ = partial( 380 | fincausal_convert_example_to_features, 381 | max_seq_length=max_seq_length, 382 | doc_stride=doc_stride, 383 | is_training=is_training, 384 | ) 385 | features: List[FinCausalFeatures] = list( 386 | tqdm( 387 | p.imap(annotate_, examples, chunksize=32), 388 | total=len(examples), 389 | desc="convert FinCausal examples to features", 390 | position=0, 391 | leave=True 392 | ) 393 | ) 394 | new_features = [] 395 | unique_id = 1000000000 396 | example_index = 0 397 | for example_features in tqdm(features, total=len(features), desc="add example index and unique id", 398 | position=0, leave=True): 399 | if not example_features: 400 | continue 401 | for example_feature in example_features: 402 | example_feature.example_index = example_index 403 | example_feature.unique_id = unique_id 404 | new_features.append(example_feature) 405 | unique_id += 1 406 | example_index += 1 407 | features = new_features 408 | del new_features 409 | if return_dataset == "pt": 410 | 411 | # Convert to Tensors and build dataset 412 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 413 | all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 414 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 415 | all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) 416 | all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) 417 | all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float) 418 | 419 | if not is_training: 420 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 421 | dataset = TensorDataset( 422 | all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask 423 | ) 424 | else: 425 | all_cause_start_positions = torch.tensor([f.cause_start_position for f in features], dtype=torch.long) 426 | all_cause_end_positions = torch.tensor([f.cause_end_position for f in features], dtype=torch.long) 427 | all_effect_start_positions = torch.tensor([f.effect_start_position for f in features], dtype=torch.long) 428 | all_effect_end_positions = torch.tensor([f.effect_end_position for f in features], dtype=torch.long) 429 | dataset = TensorDataset( 430 | all_input_ids, 431 | all_attention_masks, 432 | all_token_type_ids, 433 | all_cause_start_positions, 434 | all_cause_end_positions, 435 | all_effect_start_positions, 436 | all_effect_end_positions, 437 | all_cls_index, 438 | all_p_mask, 439 | all_is_impossible, 440 | ) 441 | 442 | return features, dataset 443 | return features 444 | 445 | 446 | def _run_split_on_punc(text: str) -> str: 447 | chars = list(text) 448 | i = 0 449 | start_new_word = True 450 | output = [] 451 | while i < len(chars): 452 | char = chars[i] 453 | if _is_punctuation(char): 454 | output.append([char]) 455 | start_new_word = True 456 | else: 457 | if start_new_word: 458 | output.append([]) 459 | start_new_word = False 460 | output[-1].append(char) 461 | i += 1 462 | 463 | return " ".join(["".join(x) for x in output]) 464 | --------------------------------------------------------------------------------
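Before the scoring script, one detail of the feature conversion above deserves a worked illustration: fincausal_convert_example_to_features cuts long token sequences into overlapping windows, taking a window of at most max_seq_length tokens every doc_stride tokens so that every token appears in at least one window; _check_is_max_context then keeps, for each token, the window where it has the most surrounding context. A minimal sketch of that sliding window, ignoring special tokens and padding, with illustrative sizes:

# 10 toy tokens, windows of 6 taken every 3 tokens
doc_tokens = [f"tok{i}" for i in range(10)]
max_seq_length, doc_stride = 6, 3

spans = []
while len(spans) * doc_stride < len(doc_tokens):
    start = len(spans) * doc_stride
    spans.append(doc_tokens[start:start + max_seq_length])

print([(s[0], s[-1]) for s in spans])
# [('tok0', 'tok5'), ('tok3', 'tok8'), ('tok6', 'tok9'), ('tok9', 'tok9')]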
/src/fincausal_evaluation/task2_evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyrights YseopLab 2020. All rights reserved. 4 | # Original evaluation file from https://github.com/yseop/YseopLab/blob/develop/FNP_2020_FinCausal/scoring/task2/task2_evaluate.py 5 | """ task2_evaluate.py - Scoring program for Fincausal 2020 Task 2 6 | 7 | usage: task2_evaluate.py [-h] {from-folder,from-file} ... 8 | 9 | positional arguments: 10 | {from-folder,from-file} 11 | Use from-file for basic mode or from-folder for 12 | Codalab compatible mode 13 | 14 | Usage 1: Folder mode 15 | 16 | usage: task2_evaluate.py from-folder [-h] input output 17 | 18 | Codalab mode with input and output folders 19 | 20 | positional arguments: 21 | input input folder with ref (reference) and res (result) sub folders 22 | output output folder where scores.txt is written 23 | 24 | optional arguments: 25 | -h, --help show this help message and exit 26 | task2_evaluate input output 27 | 28 | input, output folders must follow the Codalab competition convention for scoring bundle 29 | e.g.
30 | ├───input 31 | │ ├───ref 32 | │ └───res 33 | └───output 34 | 35 | Usage 2: File mode 36 | 37 | usage: task2_evaluate.py from-file [-h] [--ref_file REF_FILE] pred_file [score_file] 38 | 39 | Basic mode with path to input and output files 40 | 41 | positional arguments: 42 | ref_file reference file (default: ../../data/fnp2020-fincausal-task2.csv) 43 | pred_file prediction file to evaluate 44 | score_file path to output score file (or stdout if not provided) 45 | 46 | optional arguments: 47 | -h, --help show this help message and exit 48 | """ 49 | import argparse 50 | import logging 51 | import os 52 | import unittest 53 | 54 | from collections import namedtuple 55 | 56 | import nltk 57 | 58 | from sklearn import metrics 59 | 60 | 61 | def build_token_index(text): 62 | """ 63 | build a dictionary of all tokenized items from text with their respective positions in the text. 64 | E.g. "this is a basic example of a basic method" returns 65 | {'this': [0], 'is': [1], 'a': [2, 6], 'basic': [3, 7], 'example': [4], 'of': [5], 'method': [8]} 66 | :param text: reference text to index 67 | :return: dict() of text token, each token with their respective position(s) in the text 68 | """ 69 | tokens = nltk.word_tokenize(text) 70 | token_index = {} 71 | for position, token in enumerate(tokens): 72 | if token in token_index: 73 | token_index[token].append(position) 74 | else: 75 | token_index[token] = [position] 76 | return tokens, token_index 77 | 78 | 79 | def get_tokens_sequence(text, token_index): 80 | tokens = nltk.word_tokenize(text) 81 | # build list of possible position for each token 82 | # positions = [word_index[word] for word in words] 83 | positions = [] 84 | for token in tokens: 85 | if token in token_index: 86 | positions.append(token_index[token]) 87 | continue 88 | # Special case when '.' is not tokenized properly 89 | alt_token = ''.join([token, '.']) 90 | if alt_token in token_index: 91 | logging.debug(f'tokenize fix ".": {alt_token}') 92 | positions.append(token_index[alt_token]) 93 | # TODO: discard the next token if == '.' ? 94 | continue 95 | # Special case when/if ',' is not tokenized properly - TBC 96 | # alt_token = ''.join([token, ',']) 97 | # if alt_token in token_index: 98 | # logging.debug(f'tokenize fix ",": {alt_token}') 99 | # positions.append(token_index[alt_token]) 100 | # continue 101 | else: 102 | logging.warning(f'get_tokens_sequence "{token}" discarded') 103 | # No matching ? stop here 104 | if len(positions) == 0: 105 | return positions 106 | # recursively process the list of token positions to return combinations of consecutive tokens 107 | seqs = _get_sequences(*positions) 108 | # Note: several sequences can possibly be found in the reference text, when similar text patterns are repeated 109 | # always return the longest 110 | return max(seqs, key=len) 111 | 112 | 113 | def _get_sequences(*args, value=None, path=None): 114 | """ 115 | Recursive method to select sequences of successive tokens using their position relative to the reference text. 116 | A sequence is the list of successive indexes in the tokenized reference text. 117 | Implemented as a product() of successive positions constrained by their 118 | :param args: list of list of positions 119 | :param value: position of the previous token (i.e. 
next token position must be in range [value+1, value+3] 120 | :param path: debugging - current sequence 121 | :return: 122 | """ 123 | logging.debug(path) 124 | # end of recursion 125 | if len(args) == 1: 126 | if value is not None: 127 | # return items matching constraint (i.e. within range with previous token) 128 | return [x for x in args[0] if x > value and (x < value + 3)] 129 | else: 130 | # Special case where text is restricted to a single token 131 | # return all positions on first call (i.e. value is None) 132 | return [args[0]] 133 | else: 134 | # iterate over current token possible positions and combine with other tokens from recursive call 135 | # result is a list of explored sequences (i.e. list of list of positions) 136 | result = [] 137 | for x in args[0]: 138 | # keep track of current explored sequence 139 | p = [x] if path is None else list(path + [x]) 140 | # 141 | if value is None or (x > value and (x < value + 3)): 142 | seqs = _get_sequences(*args[1:], value=x, path=p) 143 | # when recursion returns empty list and current position match constraint (either only value 144 | # or value within range) add current position as a single result 145 | if len(seqs) == 0 and (value is None or (x > value and (x < value + 3))): 146 | result.append([x]) 147 | else: 148 | # otherwise combine current position with recursion results (whether returned sequences are list 149 | # or single number) and add to the list of results for this token position 150 | for s in seqs: 151 | res = [x] + s if type(s) is list else [x, s] 152 | result.append(res) 153 | return result 154 | 155 | 156 | def encode_causal_tokens(text, cause, effect): 157 | """ 158 | Encode text, cause and effect into a single list with each token represented by their respective 159 | class labels ('-','C','E') 160 | :param text: reference text 161 | :param cause: causal substring in reference text 162 | :param effect: effect substring in reference text 163 | :return: text string converted as a list of tuple(token, label) 164 | """ 165 | # Get reference text tokens and token index 166 | logging.debug(f'Reference: {text}') 167 | words, wi = build_token_index(text) 168 | logging.debug(f'Token index: {wi}') 169 | 170 | # init labels with default class label 171 | labels = ['-' for _ in range(len(words))] 172 | 173 | # encode cause using token index 174 | logging.debug(f'Cause: {cause}') 175 | cause_seq = get_tokens_sequence(cause, wi) 176 | logging.debug(f'Cause seq.: {cause_seq}') 177 | for position in cause_seq: 178 | labels[position] = 'C' 179 | 180 | # encode effect using token index 181 | logging.debug(f'Effect: {effect}') 182 | effect_seq = get_tokens_sequence(effect, wi) 183 | logging.debug(f'Effect seq.: {effect_seq}') 184 | for position in effect_seq: 185 | labels[position] = 'E' 186 | 187 | logging.debug(labels) 188 | 189 | return zip(words, labels) 190 | 191 | 192 | def evaluate(truth, predict, classes): 193 | """ 194 | Fincausal 2020 Task 2 evaluation: returns precision, recall and F1 comparing submitting data to reference data. 
195 | :param truth: list of Task2Data(index, text, cause, effect, labels) - reference data set 196 | :param predict: list of Task2Data(index, text, cause, effect, labels) - submission data set 197 | :param classes: list of classes 198 | :return: tuple(precision, recall, f1, exact match) 199 | """ 200 | exact_match = 0 201 | y_truth = [] 202 | y_predict = [] 203 | multi = {} 204 | # First pass - process text sections with single causal relations and store others in `multi` dict() 205 | for t, p in zip(truth, predict): 206 | # Process Exact Match 207 | exact_match += 1 if all([x == y for x, y in zip(t.labels, p.labels)]) else 0 208 | # PRF: Text section with multiple causal relationship ? 209 | if t.index.count('.') == 2: 210 | # extract root index and add to the list to be processed later 211 | root_index = '.'.join(t.index.split('.')[:-1]) 212 | if root_index in multi: 213 | multi[root_index][0].append(t.labels) 214 | multi[root_index][1].append(p.labels) 215 | else: 216 | multi[root_index] = [[t.labels], [p.labels]] 217 | else: 218 | # Accumulate data for precision, recall, f1 scores 219 | y_truth.extend(t.labels) 220 | y_predict.extend(p.labels) 221 | # Second pass - deal with text sections having multiple causal relations 222 | for index, section in multi.items(): 223 | # section[0] list of possible truth labels 224 | # section[1] list of predicted labels 225 | candidates = section[1] 226 | # for each possible combination of truth labels - try to find the best match in predicted labels 227 | # then repeat, removing this match from the list of remaining predicted labels 228 | for t in section[0]: 229 | best = None 230 | for p in candidates: 231 | f1 = metrics.f1_score(t, p, labels=classes, average='weighted', zero_division=0) 232 | if best is None or f1 > best[1]: 233 | best = (p, f1) 234 | # Use best to add to global evaluation 235 | y_truth.extend(t) 236 | y_predict.extend(best[0]) 237 | # Remove best from list of candidate for next iteration 238 | candidates.remove(best[0]) 239 | # Ensure all candidate predictions have been reviewed 240 | assert len(candidates) == 0 241 | 242 | precision, recall, f1, _ = metrics.precision_recall_fscore_support(y_truth, y_predict, 243 | labels=classes, 244 | average='weighted', 245 | zero_division=0) 246 | 247 | import numpy as np 248 | """ 249 | Sklearn Multiclass confusion matrix is: 250 | y_true = ["cat", "ant", "cat", "cat", "ant", "bird"] 251 | y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"] 252 | cmc = multilabel_confusion_matrix(y_true, y_pred, labels=["ant", "bird", "cat"]) 253 | cmc 254 | array([[[3, 1], 255 | [0, 2]], 256 | 257 | [[5, 0], 258 | [1, 0]], 259 | 260 | [[2, 1], 261 | [1, 2]]]) 262 | 263 | cm_ant = cmc[0] 264 | tp = cm_ant[1,1] has been classified as ant and is ant 265 | fp = cm_ant[0,1] has been classified as ant and is sth else 266 | fn = cm_ant[1,0] has been classified as sth else than ant and is ant 267 | tn = cm_ant[0,0] has been classified as sth else than ant and is something else 268 | 269 | """ 270 | print(' ') 271 | print('raw metrics') 272 | print("#-------------------------------------------------------------------------------------#") 273 | 274 | MCM = metrics.multilabel_confusion_matrix(y_truth, y_predict, labels=classes) 275 | tp_sum = MCM[:, 1, 1] 276 | print('tp_sum ', tp_sum) 277 | pred_sum = tp_sum + MCM[:, 0, 1] 278 | print('predicted in the classes ', pred_sum) 279 | true_sum = tp_sum + MCM[:, 1, 0] 280 | print('actually in the classes (tp + fn: support) ', true_sum) 281 | 282 | """beta : float, 1.0 by 
default 283 | The strength of recall versus precision in the F-score.""" 284 | 285 | precision_ = tp_sum / pred_sum 286 | recall_ = tp_sum / true_sum 287 | beta = 1.0 288 | beta2 = beta ** 2 289 | denom = beta2 * precision_ + recall_ 290 | print('denom', denom) 291 | denom[denom == 0.] = 1 # avoid division by 0 292 | weights = true_sum 293 | print("weights", weights) 294 | print(' ') 295 | print("#-------------------------------------------------------------------------------------#") 296 | print('macro scores') 297 | f_score_ = (1 + beta2) * precision_ * recall_ / denom 298 | print('macro precision ', sum(precision_) / 3, 'macro recall', sum(recall_) / 3, 'macro f_score ', 299 | sum(f_score_) / 3) 300 | 301 | ## recompute for average from source 302 | 303 | precision_ = np.average(precision_, weights=weights) 304 | recall_ = np.average(recall_, weights=weights) 305 | f_score_ = np.average(f_score_, weights=weights) 306 | print(' ') 307 | print("#-------------------------------------------------------------------------------------#") 308 | print('recomputed weighted metrics ') 309 | print('weighted precision', precision_, 'weighted recall', recall_, 'weighted_fscore', f_score_) 310 | print(' ') 311 | print("#-------------------------------------------------------------------------------------#") 312 | print('classification report') 313 | print(metrics.classification_report(y_truth, y_predict, target_names=classes)) 314 | 315 | logging.debug(f'SKLEARN EVAL: {f1}, {precision}, {recall}') 316 | 317 | return precision, recall, f1, float(exact_match) / float(len(truth)) 318 | 319 | 320 | Task2Data = namedtuple('Task2Data', ['index', 'text', 'cause', 'effect', 'labels']) 321 | 322 | 323 | def get_data(csv_lines): 324 | """ 325 | Retrieve Task 2 data from CSV content (separator is ';') as a list of (index, text, cause, effect). 326 | :param csv_lines: 327 | :return: list of Task2Data(index, text, cause, effect, labels) 328 | """ 329 | result = [] 330 | for line in csv_lines: 331 | line = line.rstrip('\n') 332 | 333 | index, text, cause, effect = line.split(';')[:4] 334 | 335 | text = text.lstrip() 336 | cause = cause.lstrip() 337 | effect = effect.lstrip() 338 | 339 | _, labels = zip(*encode_causal_tokens(text, cause, effect)) 340 | 341 | result.append(Task2Data(index, text, cause, effect, labels)) 342 | 343 | return result 344 | 345 | 346 | def evaluate_files(gold_file, submission_file, output_file=None): 347 | """ 348 | Evaluate Precision, Recall, F1 scores between gold_file and submission_file 349 | If output_file is provided, scores are saved in this file and printed to std output. 350 | :param gold_file: path to reference data 351 | :param submission_file: path to submitted data 352 | :param output_file: path to output file as expected by Codalab competition framework 353 | :return: 354 | """ 355 | if os.path.exists(gold_file) and os.path.exists(submission_file): 356 | with open(gold_file, 'r', encoding='utf-8') as fp: 357 | ref_csv = fp.readlines() 358 | with open(submission_file, 'r', encoding='utf-8') as fp: 359 | sub_csv = fp.readlines() 360 | 361 | # Get data (skipping headers) 362 | logging.info('* Loading reference data') 363 | y_true = get_data(ref_csv[1:]) 364 | logging.info('* Loading prediction data') 365 | y_pred = get_data(sub_csv[1:]) 366 | 367 | logging.info(f'Load Data: check data set length = {len(y_true) == len(y_pred)}') 368 | logging.info(f'Load Data: check data set ref. 
text = {all([x.text == y.text for x, y in zip(y_true, y_pred)])}') 369 | assert len(y_true) == len(y_pred) 370 | assert all([x.text == y.text for x, y in zip(y_true, y_pred)]) 371 | 372 | # Process data using classes: -, C & E 373 | precision, recall, f1, exact_match = evaluate(y_true, y_pred, ['-', 'C', 'E']) 374 | 375 | scores = [ 376 | "F1: %f\n" % f1, 377 | "Recall: %f\n" % recall, 378 | "Precision: %f\n" % precision, 379 | "ExactMatch: %f\n" % exact_match 380 | ] 381 | 382 | for s in scores: 383 | print(s, end='') 384 | if output_file is not None: 385 | with open(output_file, 'w', encoding='utf-8') as fp: 386 | for s in scores: 387 | fp.write(s) 388 | else: 389 | # The submission file is most likely misnamed - report the file name we expect 390 | logging.error(f'{os.path.basename(gold_file)} not found') 391 | return # nothing was scored, so skip the control file below 392 | # Save a control file next to the submission for manual comparison of gold and predicted labels 393 | import pandas as pd 394 | df = pd.DataFrame.from_records(y_true) 395 | df.columns = ['Index', 'Text', 'Cause', 'Effect', 'TRUTH'] 396 | dfpred = pd.DataFrame.from_records(y_pred) 397 | dfpred.columns = ['Index', 'Text', 'Cause', 'Effect', 'PRED'] 398 | df['PRED'] = dfpred['PRED'] 399 | df['TRUTH'] = df['TRUTH'].apply(lambda x: ' '.join(x)) 400 | df['PRED'] = df['PRED'].apply(lambda x: ' '.join(x)) 401 | 402 | # Write the control file into the submission file's directory 403 | ctrlpath = os.path.dirname(submission_file) 404 | df.to_csv(os.path.join(ctrlpath, 'origin_control.csv'), 405 | header=True, index=False) 406 | 407 | 408 | def from_folder(args): 409 | # Folder mode - Codalab usage 410 | submit_dir = os.path.join(args.input, 'res') 411 | truth_dir = os.path.join(args.input, 'ref') 412 | output_dir = args.output 413 | 414 | if not os.path.isdir(submit_dir): 415 | logging.error("%s doesn't exist" % submit_dir) 416 | 417 | if os.path.isdir(submit_dir) and os.path.isdir(truth_dir): 418 | if not os.path.exists(output_dir): 419 | os.makedirs(output_dir) 420 | 421 | o_file = os.path.join(output_dir, 'scores.txt') 422 | 423 | gold_list = os.listdir(truth_dir) 424 | for gold in gold_list: 425 | g_file = os.path.join(truth_dir, gold) 426 | s_file = os.path.join(submit_dir, gold) 427 | 428 | evaluate_files(g_file, s_file, o_file) 429 | 430 | 431 | def from_file(args): 432 | return evaluate_files(args.ref_file, args.pred_file, args.score_file) 433 | 434 | 435 | def main(): 436 | parser = argparse.ArgumentParser() 437 | subparsers = parser.add_subparsers(help='Use from-file for basic mode or from-folder for Codalab compatible mode') 438 | 439 | command1_parser = subparsers.add_parser('from-folder', description='Codalab mode with input and output folders') 440 | command1_parser.set_defaults(func=from_folder) 441 | command1_parser.add_argument('input', help='input folder with ref (reference) and res (result) sub folders') 442 | command1_parser.add_argument('output', help='output folder where scores.txt is written') 443 | 444 | command2_parser = subparsers.add_parser('from-file', description='Basic mode with path to input and output files') 445 | command2_parser.set_defaults(func=from_file) 446 | command2_parser.add_argument('--ref_file', default='../../data/fnp2020-fincausal-task2.csv', help='reference file') 447 | command2_parser.add_argument('pred_file', help='prediction file to evaluate') 448 | command2_parser.add_argument('score_file', nargs='?', default=None, 449 | help='path to output score file (or stdout if not provided)') 450 | 451 | logging.basicConfig(level=logging.INFO, 452 | filename=None, 453 | format='%(levelname)-7s| %(message)s') 454 | 455 |
args = parser.parse_args() 456 | if 'func' in args: 457 | exit(args.func(args)) 458 | else: 459 | parser.print_usage() 460 | exit(1) 461 | 462 | 463 | if __name__ == '__main__': 464 | main() 465 | 466 | 467 | # Tests, which can be executed with `python -m unittest task2_evaluate`. 468 | class Test(unittest.TestCase): 469 | def _process_test_ok(self, t_text, p_text, t_labels, p_labels, f1, precision, recall, exact_match): 470 | # Load data 471 | y_true = get_data(t_text) 472 | y_pred = get_data(p_text) 473 | with self.subTest(value='encode_causal_truth'): 474 | self.assertEqual(y_true[0].labels, t_labels) 475 | with self.subTest(value='encode_causal_pred'): 476 | self.assertEqual(y_pred[0].labels, p_labels) 477 | # Evaluate precision, recall, f1 and exact matches 478 | result = evaluate(y_true, y_pred, ['-', 'C', 'E']) 479 | # Round result to 2 decimals 480 | result = tuple(map(lambda x: round(x, 2), result)) 481 | with self.subTest(value='evaluate'): 482 | self.assertEqual(result, (precision, recall, f1, exact_match)) 483 | 484 | def test_0(self): 485 | """ Identity """ 486 | self._process_test_ok( 487 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj.; aaa bbbb; hhh i jjjj.\n'], 488 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj.; aaa bbbb; hhh i jjjj.\n'], 489 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E', 'E'), 490 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E', 'E'), 491 | 1.00, 1.00, 1.00, 1.00 492 | ) 493 | 494 | def test_1(self): 495 | """ single failure """ 496 | self._process_test_ok( 497 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa bbbb; hhh i jjjj\n'], 498 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb; hhh i jjjj\n'], 499 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E'), 500 | ('-', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E'), 501 | 0.92, 0.9, 0.89, 0.0 502 | ) 503 | 504 | def test_2(self): 505 | """ 1 missed 1 error """ 506 | self._process_test_ok( 507 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa bbbb; hhh i jjjj\n'], 508 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb; gggg hhh i jjjj\n'], 509 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E'), 510 | ('-', 'C', '-', '-', '-', '-', 'E', 'E', 'E', 'E'), 511 | 0.82, 0.8, 0.79, 0.0 512 | ) 513 | 514 | def test_3(self): 515 | """ total failure """ 516 | self._process_test_ok( 517 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj.; aaa bbbb; hhh i jjjj.\n'], 518 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj.; ccc d; eeee ffff gggg\n'], 519 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E', 'E'), 520 | ('-', '-', 'C', 'C', 'E', 'E', 'E', '-', '-', '-', '-'), 521 | 0.0, 0.0, 0.0, 0.0 522 | ) 523 | 524 | def test_4(self): 525 | """ punctuation """ 526 | self._process_test_ok( 527 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj.; aaa bbbb; hhh i jjjj.\n'], 528 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj.; aaa bbbb; hhh i jjjj\n'], 529 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E', 'E'), 530 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E', '-'), 531 | 0.92, 0.91, 0.91, 0.0 532 | ) 533 | 534 | def test_5(self): 535 | """ 2 missed + punctuation """ 536 | self._process_test_ok( 537 | ['1.0; aaa bbbb ccc d aaa bbbb ccc hhh i jjjj.; aaa bbbb ccc d; hhh i jjjj.\n'], 538 | ['1.0; aaa bbbb ccc d aaa bbbb ccc hhh i jjjj.; aaa bbbb; hhh i jjjj\n'], 539 | ('C', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E', 'E'), 540 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E', '-'), 541 | 0.86, 0.73, 0.74, 0.0 542 | ) 543 | 544 | def test_6(self): 545 | """ non consecutive 
tokens (out of range) """ 546 | self._process_test_ok( 547 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa bbbb ccc d; hhh i jjjj\n'], 548 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa d; hhh i jjjj\n'], 549 | ('C', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'), 550 | ('C', '-', '-', '-', '-', '-', '-', 'E', 'E', 'E'), 551 | 0.85, 0.7, 0.66, 0.0 552 | ) 553 | 554 | def test_7(self): 555 | """ non consecutive tokens (in range) """ 556 | self._process_test_ok( 557 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa bbbb ccc d; hhh i jjjj\n'], 558 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb d; hhh i jjjj\n'], 559 | ('C', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'), 560 | ('-', 'C', '-', 'C', '-', '-', '-', 'E', 'E', 'E'), 561 | 0.88, 0.8, 0.79, 0.0 562 | ) 563 | 564 | def test_8(self): 565 | """ no effect """ 566 | self._process_test_ok( 567 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d; hhh i jjjj\n'], 568 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d;\n'], 569 | ('-', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'), 570 | ('-', 'C', 'C', 'C', '-', '-', '-', '-', '-', '-'), 571 | 0.53, 0.7, 0.59, 0.0 572 | ) 573 | 574 | def test_9(self): 575 | """ no cause """ 576 | self._process_test_ok( 577 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d; hhh i jjjj\n'], 578 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; ; bbbb ccc d\n'], 579 | ('-', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'), 580 | ('-', 'E', 'E', 'E', '-', '-', '-', '-', '-', '-'), 581 | 0.23, 0.4, 0.29, 0.0 582 | ) 583 | 584 | def test_10(self): 585 | """ no cause, no effect """ 586 | self._process_test_ok( 587 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d; hhh i jjjj\n'], 588 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; ; \n'], 589 | ('-', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'), 590 | ('-', '-', '-', '-', '-', '-', '-', '-', '-', '-'), 591 | 0.16, 0.4, 0.23, 0.0 592 | ) 593 | 594 | def test_11(self): 595 | """ all cause """ 596 | self._process_test_ok( 597 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d; hhh i jjjj\n'], 598 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; \n'], 599 | ('-', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'), 600 | ('C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C'), 601 | 0.09, 0.3, 0.14, 0.0 602 | ) 603 | 604 | def test_12(self): 605 | """ half cause """ 606 | self._process_test_ok( 607 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d; hhh i jjjj\n'], 608 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa ccc eeee gggg i;\n'], 609 | ('-', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'), 610 | ('C', '-', 'C', '-', 'C', '-', 'C', '-', 'C', '-'), 611 | 0.14, 0.2, 0.16, 0.0 612 | ) 613 | 614 | def test_13(self): 615 | """ all effect """ 616 | self._process_test_ok( 617 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d; hhh i jjjj\n'], 618 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; ; aaa bbbb ccc d eeee ffff gggg hhh i jjjj\n'], 619 | ('-', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'), 620 | ('E', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'E'), 621 | 0.09, 0.3, 0.14, 0.0 622 | ) 623 | -------------------------------------------------------------------------------- /src/evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
3 | # Copyright 2020 Guillaume Becquin. 4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import collections 19 | import csv 20 | import json 21 | import logging 22 | import math 23 | from pathlib import Path 24 | from typing import Dict, List, Tuple, Union 25 | 26 | import torch 27 | from torch.nn import Module 28 | from torch.utils.data import SequentialSampler, DataLoader 29 | from tqdm import tqdm 30 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 31 | 32 | from .config import RunConfig 33 | from .fincausal_evaluation.task2_evaluate import encode_causal_tokens, Task2Data 34 | from .data import FinCausalResult, FinCausalFeatures, FinCausalExample 35 | from .preprocessing import load_examples 36 | from .fincausal_evaluation.task2_evaluate import evaluate as official_evaluate 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | def to_list(tensor: torch.Tensor) -> List: 42 | return tensor.detach().cpu().tolist() 43 | 44 | 45 | def predict(model: Module, 46 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 47 | device: torch.device, 48 | file_path: Path, 49 | model_type: str, 50 | output_dir: Path, 51 | run_config: RunConfig) -> Tuple[List[FinCausalExample], collections.OrderedDict]: 52 | dataset, examples, features = load_examples(file_path=file_path, 53 | tokenizer=tokenizer, 54 | output_examples=True, 55 | evaluate=True, 56 | run_config=run_config) 57 | 58 | if not output_dir.is_dir(): 59 | output_dir.mkdir(parents=True, exist_ok=True) 60 | 61 | eval_sampler = SequentialSampler(dataset) 62 | eval_dataloader = DataLoader(dataset, 63 | sampler=eval_sampler, 64 | batch_size=run_config.eval_batch_size) 65 | 66 | # Start evaluation 67 | logger.info("***** Running evaluation *****") 68 | logger.info(" Num examples = %d", len(dataset)) 69 | logger.info(" Batch size = %d", run_config.eval_batch_size) 70 | 71 | all_results = [] 72 | sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence 73 | 74 | for batch in tqdm(eval_dataloader, desc="Evaluating", position=0, leave=True): 75 | model.eval() 76 | batch = tuple(t.to(device) for t in batch) 77 | 78 | with torch.no_grad(): 79 | inputs = { 80 | "input_ids": batch[0], 81 | "attention_mask": batch[1], 82 | "token_type_ids": batch[2], 83 | } 84 | 85 | if model_type in ["xlm", "roberta", "distilbert", "camembert"]: 86 | del inputs["token_type_ids"] 87 | 88 | example_indices = batch[3] 89 | outputs = model(**inputs) 90 | 91 | for i, example_index in enumerate(example_indices): 92 | eval_feature = features[example_index.item()] 93 | unique_id = int(eval_feature.unique_id) 94 | 95 | output = [to_list(output[i]) for output in outputs] 96 | start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits = output 97 | result = FinCausalResult(unique_id, 98 | start_cause_logits, end_cause_logits, 99 | start_effect_logits, end_effect_logits) 100 | 101 | all_results.append(result) 102 | 103 
| # Compute predictions 104 | predictions = compute_predictions_logits( 105 | examples, 106 | features, 107 | all_results, 108 | output_dir, 109 | sequence_added_tokens, 110 | run_config 111 | ) 112 | 113 | return examples, predictions 114 | 115 | 116 | def evaluate(model: Module, 117 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 118 | device: torch.device, 119 | file_path: Path, 120 | model_type: str, 121 | output_dir: Path, 122 | run_config: RunConfig) -> Dict: 123 | examples, predictions = predict(model=model, 124 | tokenizer=tokenizer, 125 | device=device, 126 | file_path=file_path, 127 | model_type=model_type, 128 | output_dir=output_dir, 129 | run_config=run_config) 130 | 131 | # Compute the F1 and exact scores. 132 | results, correct, wrong = compute_metrics(examples, predictions) 133 | output_prediction_file_correct = output_dir / "predictions_correct.json" 134 | output_prediction_file_wrong = output_dir / "predictions_wrong.json" 135 | 136 | with output_prediction_file_correct.open('w') as writer: 137 | writer.write(json.dumps(correct, indent=4) + "\n") 138 | 139 | with output_prediction_file_wrong.open('w') as writer: 140 | writer.write(json.dumps(wrong, indent=4) + "\n") 141 | 142 | return results 143 | 144 | 145 | def get_data_from_list(input_data: List[List[str]]) -> List[Task2Data]: 146 | """ 147 | :param input_data: list of inputs (example id, text, cause, effect) 148 | :return: list of Task2Data(index, text, cause, effect, labels) 149 | """ 150 | result = [] 151 | for example in input_data: 152 | (index, text, cause, effect) = tuple(example) 153 | 154 | text = text.lstrip() 155 | cause = cause.lstrip() 156 | effect = effect.lstrip() 157 | 158 | _, labels = zip(*encode_causal_tokens(text, cause, effect)) 159 | 160 | result.append(Task2Data(index, text, cause, effect, labels)) 161 | 162 | return result 163 | 164 | 165 | def compute_metrics(examples: List[FinCausalExample], predictions: collections.OrderedDict) \ 166 | -> Tuple[Dict, List[Dict], List[Dict]]: 167 | y_true = [] 168 | y_pred = [] 169 | 170 | for example in examples: 171 | y_true.append([example.example_id, example.context_text, example.cause_text, example.effect_text]) 172 | prediction = predictions[example.example_id] 173 | y_pred.append([example.example_id, example.context_text, prediction['cause_text'], prediction['effect_text']]) 174 | 175 | all_correct = list() 176 | all_wrong = list() 177 | for y_true_ex, y_pred_ex in zip(y_true, y_pred): 178 | if (y_true_ex[2] == y_pred_ex[2]) and (y_true_ex[3] == y_pred_ex[3]): 179 | all_correct.append({'text': y_true_ex[1], 180 | 'cause_true': y_true_ex[2], 181 | 'effect_true': y_true_ex[3], 182 | 'cause_pred': y_pred_ex[2], 183 | 'effect_pred': y_pred_ex[3] 184 | }) 185 | else: 186 | all_wrong.append({'text': y_true_ex[1], 187 | 'cause_true': y_true_ex[2], 188 | 'effect_true': y_true_ex[3], 189 | 'cause_pred': y_pred_ex[2], 190 | 'effect_pred': y_pred_ex[3] 191 | }) 192 | logging.info('* Loading reference data') 193 | y_true = get_data_from_list(y_true) 194 | logging.info('* Loading prediction data') 195 | y_pred = get_data_from_list(y_pred) 196 | logging.info(f'Load Data: check data set length = {len(y_true) == len(y_pred)}') 197 | logging.info(f'Load Data: check data set ref. 
198 | assert len(y_true) == len(y_pred)
199 | assert all([x.text == y.text for x, y in zip(y_true, y_pred)])
200 | 
201 | precision, recall, f1, exact_match = official_evaluate(y_true, y_pred, ['-', 'C', 'E'])
202 | 
203 | scores = [
204 | "F1: %f\n" % f1,
205 | "Recall: %f\n" % recall,
206 | "Precision: %f\n" % precision,
207 | "ExactMatch: %f\n" % exact_match
208 | ]
209 | for s in scores:
210 | print(s, end='')
211 | 
212 | return {
213 | 'F1': f1,
214 | 'Precision': precision,
215 | 'Recall': recall,
216 | 'ExactMatch': exact_match
217 | }, all_correct, all_wrong
218 | 
219 | 
220 | _PrelimPrediction = collections.namedtuple(
221 | "PrelimPrediction", ["feature_index",
222 | "start_index_cause",
223 | "end_index_cause",
224 | "start_logit_cause",
225 | "end_logit_cause",
226 | "start_index_effect",
227 | "end_index_effect",
228 | "start_logit_effect",
229 | "end_logit_effect"]
230 | )
231 | 
232 | _NbestPrediction = collections.namedtuple(
233 | "NbestPrediction", ["text_cause",
234 | "start_index_cause",
235 | "end_index_cause",
236 | "start_logit_cause",
237 | "end_logit_cause",
238 | "text_effect",
239 | "start_index_effect",
240 | "end_index_effect",
241 | "start_logit_effect",
242 | "end_logit_effect"]
243 | )
244 | 
245 | 
246 | def filter_impossible_spans(features,
247 | unique_id_to_result: Dict,
248 | n_best_size: int,
249 | max_answer_length: int,
250 | sequence_added_tokens: int,
251 | sentence_boundary_heuristic: bool = False,
252 | full_sentence_heuristic: bool = False,
253 | shared_sentence_heuristic: bool = False,
254 | ) -> List[_PrelimPrediction]:
255 | prelim_predictions = []
256 | 
257 | for (feature_index, feature) in enumerate(features):
258 | result = unique_id_to_result[feature.unique_id]
259 | assert isinstance(feature, FinCausalFeatures)
260 | assert isinstance(result, FinCausalResult)
261 | sentence_offsets = [offset for offset in [feature.sentence_2_offset, feature.sentence_3_offset] if
262 | offset is not None]
263 | start_indexes_cause = _get_best_indexes(result.start_cause_logits, n_best_size)
264 | end_indexes_cause = _get_best_indexes(result.end_cause_logits, n_best_size)
265 | start_logits_cause = result.start_cause_logits
266 | end_logits_cause = result.end_cause_logits
267 | start_indexes_effect = _get_best_indexes(result.start_effect_logits, n_best_size)
268 | end_indexes_effect = _get_best_indexes(result.end_effect_logits, n_best_size)
269 | start_logits_effect = result.start_effect_logits
270 | end_logits_effect = result.end_effect_logits
271 | 
272 | for raw_start_index_cause in start_indexes_cause:
273 | for raw_end_index_cause in end_indexes_cause:
274 | cause_pairs = [(raw_start_index_cause, raw_end_index_cause)]
275 | # Heuristic: a cause cannot span across multiple sentences
276 | if len(sentence_offsets) > 0 and sentence_boundary_heuristic:
277 | for sentence_offset in sentence_offsets:
278 | if raw_start_index_cause < sentence_offset < raw_end_index_cause:
279 | cause_pairs = [(raw_start_index_cause, sentence_offset),
280 | (sentence_offset + 1, raw_end_index_cause)]
281 | for start_index_cause, end_index_cause in cause_pairs:
282 | for raw_start_index_effect in start_indexes_effect:
283 | for raw_end_index_effect in end_indexes_effect:
284 | effect_pairs = [(raw_start_index_effect, raw_end_index_effect)]
285 | # Heuristic: an effect cannot span across multiple sentences
286 | if len(sentence_offsets) > 0 and sentence_boundary_heuristic:
287 | for sentence_offset in sentence_offsets:
288 | if raw_start_index_effect < sentence_offset < raw_end_index_effect:
289 | effect_pairs = [(raw_start_index_effect, sentence_offset),
290 | (sentence_offset + 1, raw_end_index_effect)]
291 | for start_index_effect, end_index_effect in effect_pairs:
292 | if (start_index_cause <= start_index_effect) and (
293 | end_index_cause >= start_index_effect):
294 | continue
295 | if (start_index_effect <= start_index_cause) and (
296 | end_index_effect >= start_index_cause):
297 | continue
298 | if start_index_effect >= len(feature.tokens) or start_index_cause >= len(
299 | feature.tokens):
300 | continue
301 | if end_index_effect >= len(feature.tokens) or end_index_cause >= len(feature.tokens):
302 | continue
303 | if start_index_effect not in feature.token_to_orig_map or \
304 | start_index_cause not in feature.token_to_orig_map:
305 | continue
306 | if end_index_effect not in feature.token_to_orig_map or \
307 | end_index_cause not in feature.token_to_orig_map:
308 | continue
309 | if (not feature.token_is_max_context.get(start_index_effect, False)) or \
310 | (not feature.token_is_max_context.get(start_index_cause, False)):
311 | continue
312 | if end_index_cause < start_index_cause:
313 | continue
314 | if end_index_effect < start_index_effect:
315 | continue
316 | length_cause = end_index_cause - start_index_cause + 1
317 | length_effect = end_index_effect - start_index_effect + 1
318 | if length_cause > max_answer_length:
319 | continue
320 | if length_effect > max_answer_length:
321 | continue
322 | 
323 | # Heuristics extending the prediction spans
324 | if full_sentence_heuristic or shared_sentence_heuristic:
325 | num_tokens = len(feature.tokens)
326 | all_sentence_offsets = [sequence_added_tokens] + \
327 | [offset + 1 for offset in sentence_offsets] + \
328 | [num_tokens]
329 | cause_sentences = []
330 | effect_sentences = []
331 | for sentence_idx in range(len(all_sentence_offsets) - 1):
332 | sentence_start, sentence_end = all_sentence_offsets[sentence_idx], \
333 | all_sentence_offsets[sentence_idx + 1]
334 | if sentence_start <= start_index_cause < sentence_end:
335 | cause_sentences.append(sentence_idx)
336 | if sentence_start <= start_index_effect < sentence_end:
337 | effect_sentences.append(sentence_idx)
338 | 
339 | # Heuristic (first rule): if a sentence contains only one clause, the clause is
340 | # extended to the entire sentence.
341 | if set(cause_sentences).isdisjoint(set(effect_sentences)) \
342 | and full_sentence_heuristic:
343 | start_index_cause = min(
344 | [all_sentence_offsets[sent] for sent in cause_sentences])
345 | end_index_cause = max(
346 | [all_sentence_offsets[sent + 1] - 1 for sent in cause_sentences])
347 | start_index_effect = min(
348 | [all_sentence_offsets[sent] for sent in effect_sentences])
349 | end_index_effect = max(
350 | [all_sentence_offsets[sent + 1] - 1 for sent in effect_sentences])
351 | # Heuristic (third rule): if a sentence contains only two clauses, the spans are
352 | # extended as much as possible.
353 | if not set(cause_sentences).isdisjoint(set(effect_sentences)) \
354 | and shared_sentence_heuristic \
355 | and len(cause_sentences) == 1 \
356 | and len(effect_sentences) == 1:
357 | if start_index_cause < start_index_effect:
358 | start_index_cause = min(
359 | [all_sentence_offsets[sent] for sent in cause_sentences])
360 | end_index_effect = max(
361 | [all_sentence_offsets[sent + 1] - 1 for sent in effect_sentences])
362 | else:
363 | start_index_effect = min(
364 | [all_sentence_offsets[sent] for sent in effect_sentences])
365 | end_index_cause = max(
366 | [all_sentence_offsets[sent + 1] - 1 for sent in cause_sentences])
367 | 
368 | prelim_predictions.append(
369 | _PrelimPrediction(
370 | feature_index=feature_index,
371 | start_index_cause=start_index_cause,
372 | end_index_cause=end_index_cause,
373 | start_logit_cause=start_logits_cause[start_index_cause],
374 | end_logit_cause=end_logits_cause[end_index_cause],
375 | start_index_effect=start_index_effect,
376 | end_index_effect=end_index_effect,
377 | start_logit_effect=start_logits_effect[start_index_effect],
378 | end_logit_effect=end_logits_effect[end_index_effect]
379 | )
380 | )
381 | return prelim_predictions
382 | 
383 | 
384 | def get_predictions(preliminary_predictions: List[_PrelimPrediction], n_best_size: int,
385 | features: List[FinCausalFeatures], example: FinCausalExample) -> List[_NbestPrediction]:
386 | seen_predictions_cause = {}
387 | seen_predictions_effect = {}
388 | nbest = []
389 | for prediction in preliminary_predictions:
390 | if len(nbest) >= n_best_size:
391 | break
392 | feature = features[prediction.feature_index]
393 | if prediction.start_index_cause > 0: # this is a non-null prediction
394 | orig_doc_start_cause = feature.token_to_orig_map[prediction.start_index_cause]
395 | orig_doc_end_cause = feature.token_to_orig_map[prediction.end_index_cause]
396 | orig_doc_start_cause_char = example.word_to_char_mapping[orig_doc_start_cause]
397 | if orig_doc_end_cause < len(example.word_to_char_mapping) - 1:
398 | orig_doc_end_cause_char = example.word_to_char_mapping[orig_doc_end_cause + 1]
399 | else:
400 | orig_doc_end_cause_char = len(example.context_text)
401 | final_text_cause = example.context_text[orig_doc_start_cause_char: orig_doc_end_cause_char]
402 | final_text_cause = final_text_cause.strip()
403 | 
404 | orig_doc_start_effect = feature.token_to_orig_map[prediction.start_index_effect]
405 | orig_doc_end_effect = feature.token_to_orig_map[prediction.end_index_effect]
406 | orig_doc_start_effect_char = example.word_to_char_mapping[orig_doc_start_effect]
407 | if orig_doc_end_effect < len(example.word_to_char_mapping) - 1:
408 | orig_doc_end_effect_char = example.word_to_char_mapping[orig_doc_end_effect + 1]
409 | else:
410 | orig_doc_end_effect_char = len(example.context_text)
411 | final_text_effect = example.context_text[orig_doc_start_effect_char: orig_doc_end_effect_char]
412 | final_text_effect = final_text_effect.strip()
413 | 
414 | if final_text_cause in seen_predictions_cause and final_text_effect in seen_predictions_effect:
415 | continue
416 | 
417 | seen_predictions_cause[final_text_cause] = True
418 | seen_predictions_effect[final_text_effect] = True
419 | else:
420 | final_text_cause = final_text_effect = ""
421 | seen_predictions_cause[final_text_cause] = True
422 | seen_predictions_effect[final_text_effect] = True
423 | orig_doc_start_cause = prediction.start_index_cause
424 | orig_doc_end_cause = prediction.end_index_cause
425 | orig_doc_start_effect = prediction.start_index_effect
426 | orig_doc_end_effect = prediction.end_index_effect 427 | 428 | nbest.append( 429 | _NbestPrediction(text_cause=final_text_cause, 430 | start_logit_cause=prediction.start_logit_cause, 431 | end_logit_cause=prediction.end_logit_cause, 432 | start_index_cause=orig_doc_start_cause, 433 | end_index_cause=orig_doc_end_cause, 434 | text_effect=final_text_effect, 435 | start_logit_effect=prediction.start_logit_effect, 436 | end_logit_effect=prediction.end_logit_effect, 437 | start_index_effect=orig_doc_start_effect, 438 | end_index_effect=orig_doc_end_effect, 439 | )) 440 | return nbest 441 | 442 | 443 | def compute_predictions_logits( 444 | all_examples: List[FinCausalExample], 445 | all_features: List[FinCausalFeatures], 446 | all_results: List[FinCausalResult], 447 | output_dir: Path, 448 | sequence_added_tokens: int, 449 | run_config: RunConfig) -> collections.OrderedDict: 450 | example_index_to_features = collections.defaultdict(list) 451 | for feature in all_features: 452 | example_index_to_features[feature.example_index].append(feature) 453 | 454 | unique_id_to_result = {} 455 | for result in all_results: 456 | unique_id_to_result[result.unique_id] = result 457 | 458 | all_predictions = collections.OrderedDict() 459 | all_nbest_json = collections.OrderedDict() 460 | 461 | for (example_index, example) in enumerate(all_examples): 462 | features = example_index_to_features[example_index] 463 | suffix_index = 0 464 | if example.example_id.count('.') == 2 and run_config.top_n_sentences: 465 | suffix_index = int(example.example_id.split('.')[-1]) 466 | prelim_predictions = filter_impossible_spans(features, 467 | unique_id_to_result, 468 | run_config.n_best_size, 469 | run_config.max_answer_length, 470 | sequence_added_tokens, 471 | run_config.sentence_boundary_heuristic, 472 | run_config.full_sentence_heuristic, 473 | run_config.shared_sentence_heuristic, ) 474 | prelim_predictions = sorted(list(set(prelim_predictions)), 475 | key=lambda x: (x.start_logit_cause + x.end_logit_cause + 476 | x.start_logit_effect + x.end_logit_effect), 477 | reverse=True) 478 | 479 | nbest = get_predictions(prelim_predictions, run_config.n_best_size, features, example) 480 | 481 | # In very rare edge cases we could have no valid predictions. So we 482 | # just create a none prediction in this case to avoid failure. 
483 | if not nbest: 484 | nbest.append(_NbestPrediction(text_cause="empty", start_logit_cause=0.0, end_logit_cause=0.0, 485 | text_effect="empty", start_logit_effect=0.0, end_logit_effect=0.0, 486 | start_index_effect=0, end_index_effect=0, 487 | start_index_cause=0, end_index_cause=0)) 488 | 489 | assert len(nbest) >= 1 490 | 491 | total_scores = [] 492 | best_non_null_entry = None 493 | for entry in nbest: 494 | total_scores.append(entry.start_logit_cause + entry.end_logit_cause + 495 | entry.start_logit_effect + entry.end_logit_effect) 496 | if not best_non_null_entry: 497 | if entry.text_cause and entry.text_effect: 498 | best_non_null_entry = entry 499 | 500 | probabilities = _compute_softmax(total_scores) 501 | 502 | nbest_json = [] 503 | current_example_spans = [] 504 | for (i, entry) in enumerate(nbest): 505 | output = collections.OrderedDict() 506 | output["text"] = example.context_text 507 | output["probability"] = probabilities[i] 508 | output["cause_text"] = entry.text_cause 509 | output["cause_start_index"] = entry.start_index_cause 510 | output["cause_end_index"] = entry.end_index_cause 511 | output["cause_start_score"] = entry.start_logit_cause 512 | output["cause_end_score"] = entry.end_logit_cause 513 | output["effect_text"] = entry.text_effect 514 | output["effect_start_score"] = entry.start_logit_effect 515 | output["effect_end_score"] = entry.end_logit_effect 516 | output["effect_start_index"] = entry.start_index_effect 517 | output["effect_end_index"] = entry.end_index_effect 518 | new_span = SpanCombination(start_cause=entry.start_index_cause, end_cause=entry.end_index_cause, 519 | start_effect=entry.start_index_effect, end_effect=entry.end_index_effect) 520 | output["is_new"] = all([new_span != other for other in current_example_spans]) 521 | nbest_json.append(output) 522 | current_example_spans.append(new_span) 523 | 524 | assert len(nbest_json) >= 1 525 | if suffix_index > 0: 526 | suffix_index -= 1 527 | all_predictions[example.example_id] = {"text": nbest_json[suffix_index]["text"], 528 | "cause_text": nbest_json[suffix_index]["cause_text"], 529 | "effect_text": nbest_json[suffix_index]["effect_text"]} 530 | all_nbest_json[example.example_id] = nbest_json 531 | 532 | output_prediction_file = output_dir / "predictions.json" 533 | csv_output_prediction_file = output_dir / "predictions.csv" 534 | output_nbest_file = output_dir / "nbest_predictions.json" 535 | 536 | logger.info("Writing predictions to: %s" % output_prediction_file) 537 | with open(output_prediction_file, "w") as writer: 538 | writer.write(json.dumps(all_predictions, indent=4) + "\n") 539 | 540 | with open(csv_output_prediction_file, "w", encoding='utf-8', newline='') as writer: 541 | csv_writer = csv.writer(writer, delimiter=';') 542 | csv_writer.writerow(['Index', 'Text', 'Cause', 'Effect']) 543 | for (example_id, prediction) in all_predictions.items(): 544 | csv_writer.writerow([example_id, prediction['text'], prediction['cause_text'], prediction['effect_text']]) 545 | 546 | logger.info("Writing nbest to: %s" % output_nbest_file) 547 | with open(output_nbest_file, "w") as writer: 548 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n") 549 | 550 | return all_predictions 551 | 552 | 553 | class SpanCombination: 554 | def __init__(self, start_cause: int, end_cause: int, start_effect: int, end_effect: int): 555 | self.start_cause = start_cause 556 | self.start_effect = start_effect 557 | self.end_cause = end_cause 558 | self.end_effect = end_effect 559 | 560 | def __eq__(self, other): 561 | 
# Two combinations are treated as duplicates when the other prediction's
562 | # cause overlaps one of this prediction's spans and its effect does too.
563 | overlapping_cause = (self.start_cause <= other.start_cause <= self.end_cause) or \
564 | (self.start_cause <= other.end_cause <= self.end_cause) or \
565 | (self.start_effect <= other.start_cause <= self.end_effect) or \
566 | (self.start_effect <= other.end_cause <= self.end_effect)
567 | overlapping_effect = (self.start_effect <= other.start_effect <= self.end_effect) or \
568 | (self.start_effect <= other.end_effect <= self.end_effect) or \
569 | (self.start_cause <= other.start_effect <= self.end_cause) or \
570 | (self.start_cause <= other.end_effect <= self.end_cause)
571 | return overlapping_cause and overlapping_effect
572 | 
573 | 
574 | def _get_best_indexes(logits, n_best_size) -> List[int]:
575 | """Get the indexes of the n best logits from a list."""
576 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
577 | 
578 | best_indexes = []
579 | for i in range(len(index_and_score)):
580 | if i >= n_best_size:
581 | break
582 | best_indexes.append(index_and_score[i][0])
583 | return best_indexes
584 | 
585 | 
586 | def _compute_softmax(scores) -> List[float]:
587 | """Compute softmax probability over raw logits."""
588 | if not scores:
589 | return []
590 | 
591 | # Subtract the maximum score for numerical stability.
592 | max_score = max(scores)
593 | 
594 | exp_scores = []
595 | total_sum = 0.0
596 | for score in scores:
597 | x = math.exp(score - max_score)
598 | exp_scores.append(x)
599 | total_sum += x
600 | 
601 | probabilities = []
602 | for score in exp_scores:
603 | probabilities.append(score / total_sum)
604 | return probabilities
605 | --------------------------------------------------------------------------------
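/docs/usage_sketch.py:
--------------------------------------------------------------------------------
1 | # Editor's illustrative sketch — this file is NOT part of the original
2 | # repository. It shows how the pieces of src/evaluation.py fit together.
3 | # The model placeholder, the RunConfig construction and the output paths
4 | # below are assumptions; only the helper demos and the evaluate() call
5 | # signature are taken directly from src/evaluation.py.
6 | from pathlib import Path
7 | 
8 | import torch
9 | from transformers import BertTokenizer
10 | 
11 | from src.config import RunConfig
12 | from src.evaluation import evaluate, _get_best_indexes, _compute_softmax
13 | 
14 | # Helper semantics (grounded in src/evaluation.py): _get_best_indexes
15 | # returns the *indexes* of the n largest logits, and _compute_softmax
16 | # normalises raw logits into probabilities.
17 | assert _get_best_indexes([2.0, -1.0, 0.5, 3.0], n_best_size=2) == [3, 0]
18 | assert _compute_softmax([1.0, 1.0]) == [0.5, 0.5]
19 | 
20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 | tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
22 | 
23 | # The model must return four logit tensors per forward pass, in the order
24 | # unpacked by predict(): start/end cause logits, then start/end effect
25 | # logits (see FinCausalResult). The concrete classes live in src/models.
26 | model = ...  # hypothetical placeholder for the fine-tuned span model
27 | model.to(device)
28 | 
29 | # Assumed: RunConfig exposes the fields read by evaluation.py, e.g.
30 | # eval_batch_size, n_best_size, max_answer_length and the heuristic flags.
31 | run_config = RunConfig(...)
32 | 
33 | # evaluate() writes predictions.json, predictions.csv and
34 | # nbest_predictions.json, plus predictions_correct.json and
35 | # predictions_wrong.json, under output_dir, and returns the metrics dict
36 | # with keys 'F1', 'Precision', 'Recall' and 'ExactMatch'.
37 | results = evaluate(model=model,
38 | tokenizer=tokenizer,
39 | device=device,
40 | file_path=Path("data/fnp2020-eval.csv"),
41 | model_type="bert",
42 | output_dir=Path("output/eval"),
43 | run_config=run_config)
44 | print(results)
--------------------------------------------------------------------------------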