├── src
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── roberta.py
│   │   ├── albert.py
│   │   ├── distilbert.py
│   │   ├── bert.py
│   │   └── xlnet.py
│   ├── fincausal_evaluation
│   │   ├── __init__.py
│   │   └── task2_evaluate.py
│   ├── logging.py
│   ├── config.py
│   ├── data.py
│   ├── training.py
│   ├── preprocessing.py
│   └── evaluation.py
├── utils
│   ├── __init__.py
│   └── split_dataset.py
├── output
│   └── README.md
├── data
│   └── README.md
├── .gitignore
├── README.md
├── requirements.txt
├── main.py
└── LICENSE
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/fincausal_evaluation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/output/README.md:
--------------------------------------------------------------------------------
1 | The model outputs are stored in this directory.
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | Place the data files in this directory:
2 | - Trial dataset (`fnp2020-fincausal-task2.csv`)
3 | - Practice dataset (`fnp2020-fincausal2-task2.csv`)
4 | - Test dataset (`task2.csv`)
--------------------------------------------------------------------------------
/utils/split_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import pandas as pd
3 | from pathlib import Path
4 | import os
5 | import sys
6 | from sklearn.model_selection import train_test_split
7 |
8 | if __name__ == '__main__':
9 |
10 | fincausal_data_path = Path(os.environ.get('FINCAUSAL_DATA_PATH',
11 | os.path.dirname(os.path.realpath(sys.argv[0])) + '/../data'))
12 |
13 | input_file = fincausal_data_path / "fnp2020-fincausal-task2.csv"
14 | train_output = fincausal_data_path / "fnp2020-train.csv"
15 | dev_output = fincausal_data_path / "fnp2020-eval.csv"
16 | size = 0.1
17 | seed = 42
18 |
19 | data = pd.read_csv(input_file, delimiter=';', header=0)
20 |
21 | train, test = train_test_split(data, test_size=size, random_state=seed)
22 |
23 |     train.to_csv(train_output, header=True, sep=';', index=False)
24 |     test.to_csv(dev_output, header=True, sep=';', index=False)
25 |
--------------------------------------------------------------------------------
/src/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020, Guillaume Becquin
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Dict
16 |
17 | from src.config import ModelConfigurations, RunConfig
18 |
19 |
20 | def initialize_log_dict(model_config: ModelConfigurations,
21 | run_config: RunConfig,
22 | model_tokenizer_mapping: Dict) -> Dict:
23 | model_type = model_config.value[0]
24 | model_class, tokenizer_class = model_tokenizer_mapping[model_type]
25 |
26 | return {'MODEL_TYPE': model_config.value[0],
27 | 'MODEL_CLASS': model_class.__name__,
28 | 'TOKENIZER_CLASS': tokenizer_class.__name__,
29 | 'MODEL_NAME_OR_PATH': model_config.value[1],
30 | 'DO_TRAIN': run_config.do_train,
31 | 'DO_EVAL': run_config.do_eval,
32 | 'DO_LOWER_CASE': model_config.value[2],
33 | 'MAX_SEQ_LENGTH': run_config.max_seq_length,
34 | 'DOC_STRIDE': run_config.doc_stride,
35 | 'PER_GPU_BATCH_SIZE': run_config.train_batch_size,
36 | 'GRADIENT_ACCUMULATION_STEPS': run_config.gradient_accumulation_steps,
37 | 'WARMUP_STEPS': run_config.warmup_steps,
38 | 'LEARNING_RATE': run_config.learning_rate,
39 | 'NUM_TRAIN_EPOCHS': run_config.num_train_epochs,
40 | 'PER_GPU_EVAL_BATCH_SIZE': run_config.eval_batch_size,
41 | 'N_BEST_SIZE': run_config.n_best_size,
42 | 'MAX_ANSWER_LENGTH': run_config.max_answer_length,
43 | 'SENTENCE_BOUNDARY_HEURISTIC': run_config.sentence_boundary_heuristic,
44 | 'FULL_SENTENCE_HEURISTIC': run_config.full_sentence_heuristic,
45 | 'SHARED_SENTENCE_HEURISTIC': run_config.shared_sentence_heuristic,
46 | 'OPTIMIZER': str(run_config.optimizer_class),
47 | 'WEIGHT_DECAY': run_config.weight_decay,
48 | 'SCHEDULER_FUNCTION': str(run_config.scheduler_function)
49 | }
50 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | /BACKUP/
131 | .idea/
132 | /data/*.csv
133 | /output/**
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Span-based Causality Extraction for Financial Documents
2 | This repository contains the supporting code for the FinCausal 2020 submission titled `Span-based Causality Extraction for Financial Documents`. The model extracts cause and effect spans from financial documents, for example:
3 |
4 | |Text|Cause|Effect|
5 | | ------------- |:-----------------------:| ---------------------:|
6 | |Boussard Gavaudan Investment Management LLP bought a new position in shares of GENFIT S A/ADR in the second quarter worth about $199,000. Morgan Stanley increased its stake in shares of GENFIT S A/ADR by 24.4% in the second quarter.Morgan Stanley now owns 10,700 shares of the company’s stock worth $211,000 after purchasing an additional 2,100 shares during the period|Morgan Stanley increased its stake in shares of GENFIT S A/ADR by 24.4% in the second quarter|Morgan Stanley now owns 10,700 shares of the company’s stock worth $211,000 after purchasing an additional 2,100 shares during the period.|
7 | |Zhao found himself 60 million yuan indebted after losing 9,000 BTC in a single day (February 10, 2014)|losing 9,000 BTC in a single day (February 10, 2014)|Zhao found himself 60 million yuan indebted|
8 |
9 | (sample from the [task description: Data Processing and Metrics for FinCausal Shared Task, 2020, Mariko et al.](https://drive.google.com/file/d/1LUTJVj9ItJMZzKvy1LrCTuBK2SITzr1z/view))
10 |
11 | The system ranked 2nd on the official evaluation leaderboard and reached the following performance in post-evaluation:
12 |
13 | #### Post-evaluation performance
14 |
15 | |Metric|score|
16 | |:-------------|:-------------:|
17 | |weighted-averaged F1|95.01%|
18 | |Exact matches| 83.34%|
19 | |weighted-averaged Precision| 95.01%|
20 | |weighted-averaged Recall| 95.01%|
21 |
22 | #### Official evaluation performance
23 |
24 | |Metric|score|
25 | |:-------------|:-------------:|
26 | |weighted-averaged F1|94.66%|
27 | |Exact matches| 73.66%|
28 | |weighted-averaged Precision| 94.66%|
29 | |weighted-averaged Recall| 94.66%|
30 |
31 | The system is based on a RoBERTa span-extraction model (similar to a question-answering architecture); a full description is available in the accompanying system description paper. If you find this system useful, please cite:
32 |
33 | ```latex
34 | @inproceedings{
35 | Becquin-fincausal-2020,
36 | title ={{GBe at FinCausal 2020, Task 2: Span-based Causality Extraction for Financial Documents}},
37 | author = {Becquin, Guillaume},
38 | booktitle ={{The 1st Joint Workshop on Financial Narrative Processing and MultiLing Financial Summarisation (FNP-FNS 2020)}},
39 | year = {2020},
40 | address = {Barcelona, Spain}
41 | }
42 | ```
43 |
44 | ## Instructions
45 |
46 | 0. Install the requirements listed in `requirements.txt` (using a virtual environment is advised)
47 | 1. Generate the train / development data split by running `./utils/split_dataset.py`
48 |
49 | ### Training:
50 | 2. Run `python main.py --train`
51 |
52 | ### Evaluation:
53 | 2. Run `python main.py --eval`
54 |
55 | ### Generate predictions:
56 | 2. Run `python main.py --test`
--------------------------------------------------------------------------------
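The data and output locations are resolved from environment variables in `main.py` (`FINCAUSAL_DATA_PATH` is also read by `utils/split_dataset.py`), falling back to `./data` and `./output` next to the script; the model backbone itself is selected inside `main.py`. The sketch below is illustrative only: the paths and the alternative backbone are examples, not repository defaults.

```python
import os

# Optional location overrides read by main.py (and split_dataset.py):
os.environ["FINCAUSAL_DATA_PATH"] = "/path/to/data"      # illustrative path
os.environ["FINCAUSAL_OUTPUT_PATH"] = "/path/to/output"  # illustrative path

# The backbone is hard-coded in main.py (RoBERTaSquadLarge by default);
# switching it means editing that assignment, e.g.:
from src.config import ModelConfigurations, RunConfig

model_config = ModelConfigurations.RoBERTa                      # illustrative alternative
run_config = RunConfig(num_train_epochs=5, learning_rate=3e-5)  # the defaults
```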
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Dependencies pinned from the development environment.
2 | absl-py==0.9.0
3 | backcall==0.1.0
4 | blis==0.4.1
5 | boto3==1.12.39
6 | botocore==1.15.39
7 | cachetools==4.1.0
8 | catalogue==1.0.0
9 | certifi==2020.4.5.1
10 | chardet==3.0.4
11 | click==7.1.1
12 | colorama==0.4.3
13 | cycler==0.10.0
14 | cymem==2.0.3
15 | decorator==4.4.2
16 | docutils==0.15.2
17 | # en-core-web-lg 2.2.5 is a spaCy model: python -m spacy download en_core_web_lg
18 | filelock==3.0.12
19 | funcy==1.14
20 | google-auth==1.13.1
21 | google-auth-oauthlib==0.4.1
22 | grpcio==1.28.1
23 | idna==2.9
24 | importlib-metadata==1.6.1
25 | ipykernel==5.3.0
26 | ipython==7.15.0
27 | ipython-genutils==0.2.0
28 | jedi==0.17.0
29 | jmespath==0.9.5
30 | joblib==0.14.1
31 | jupyter-client==6.1.3
32 | jupyter-core==4.6.3
33 | kiwisolver==1.2.0
34 | Markdown==3.2.1
35 | matplotlib==3.2.2
36 | murmurhash==1.0.2
37 | nltk==3.5
38 | numpy==1.18.2
39 | oauthlib==3.1.0
40 | packaging==20.4
41 | pandas==1.0.3
42 | parso==0.7.0
43 | pickleshare==0.7.5
44 | plac==1.1.3
45 | preshed==3.0.2
46 | prompt-toolkit==3.0.5
47 | protobuf==3.11.3
48 | pyasn1==0.4.8
49 | pyasn1-modules==0.2.8
50 | Pygments==2.6.1
51 | pyparsing==2.4.7
52 | pysbd==0.2.3
53 | python-crfsuite==0.9.7
54 | python-dateutil==2.8.1
55 | pytorch-ranger==0.1.1
56 | pytz==2019.3
57 | pywin32==227; sys_platform == "win32"
58 | pyzmq==19.0.1
59 | regex==2020.4.4
60 | requests==2.23.0
61 | requests-oauthlib==1.3.0
62 | rsa==4.0
63 | s3transfer==0.3.3
64 | sacremoses==0.0.38
65 | scikit-learn==0.22.2.post1
66 | scipy==1.4.1
67 | seaborn==0.10.1
68 | sentencepiece==0.1.85
69 | six==1.14.0
70 | sklearn==0.0
71 | spacy==2.2.4
72 | srsly==1.0.2
73 | tensorboard==2.2.0
74 | tensorboard-plugin-wit==1.6.0.post3
75 | thinc==7.4.0
76 | tokenizers==0.7.0rc3
77 | torch==1.4.0
78 | torch-optimizer==0.0.1a12
79 | tornado==6.0.4
80 | tqdm==4.45.0
81 | traitlets==4.3.3
82 | transformers==3.0.1  # originally an editable install from a local source checkout
83 | urllib3==1.25.8
84 | wasabi==0.6.0
85 | wcwidth==0.2.3
86 | Werkzeug==1.0.1
87 | wheel==0.34.2
88 | zipp==3.1.0
89 |
--------------------------------------------------------------------------------
/src/models/roberta.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | # Copyright 2020 Guillaume Becquin.
4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | from torch import nn
19 | from transformers import BertPreTrainedModel, RobertaModel
20 |
21 |
22 | class RoBERTaForCauseEffect(BertPreTrainedModel):
23 | def __init__(self, config):
24 | super().__init__(config)
25 |
26 | self.roberta = RobertaModel(config)
27 | self.cause_outputs = nn.Linear(config.hidden_size, config.num_labels)
28 | self.effect_outputs = nn.Linear(config.hidden_size, config.num_labels)
29 | assert config.num_labels == 2
30 | self.init_weights()
31 |
32 | def forward(
33 | self,
34 | input_ids=None,
35 | attention_mask=None,
36 | token_type_ids=None,
37 | position_ids=None,
38 | head_mask=None,
39 | inputs_embeds=None,
40 | start_cause_positions=None,
41 | end_cause_positions=None,
42 | start_effect_positions=None,
43 | end_effect_positions=None,
44 | ):
45 | bert_output = self.roberta(
46 | input_ids=input_ids,
47 | attention_mask=attention_mask,
48 | token_type_ids=token_type_ids,
49 | position_ids=position_ids,
50 | head_mask=head_mask,
51 | inputs_embeds=inputs_embeds
52 | )
53 | hidden_states = bert_output[0] # (bs, max_query_len, dim)
54 | cause_logits = self.cause_outputs(hidden_states) # (bs, max_query_len, 2)
55 | effect_logits = self.effect_outputs(hidden_states) # (bs, max_query_len, 2)
56 | start_cause_logits, end_cause_logits = cause_logits.split(1, dim=-1)
57 | start_effect_logits, end_effect_logits = effect_logits.split(1, dim=-1)
58 |
59 | start_cause_logits = start_cause_logits.squeeze(-1) # (bs, max_query_len)
60 | end_cause_logits = end_cause_logits.squeeze(-1) # (bs, max_query_len)
61 | start_effect_logits = start_effect_logits.squeeze(-1) # (bs, max_query_len)
62 | end_effect_logits = end_effect_logits.squeeze(-1) # (bs, max_query_len)
63 |
64 | outputs = (start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits,) \
65 | + bert_output[2:]
66 | if start_cause_positions is not None \
67 | and end_cause_positions is not None \
68 | and start_effect_positions is not None \
69 | and end_effect_positions is not None:
70 | # sometimes the start/end positions are outside our model inputs, we ignore these terms
71 | ignored_index = start_cause_logits.size(1)
72 | start_cause_positions.clamp_(0, ignored_index)
73 | end_cause_positions.clamp_(0, ignored_index)
74 | start_effect_positions.clamp_(0, ignored_index)
75 | end_effect_positions.clamp_(0, ignored_index)
76 |
77 | loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
78 | start_cause_loss = loss_fct(start_cause_logits, start_cause_positions)
79 | end_cause_loss = loss_fct(end_cause_logits, end_cause_positions)
80 | start_effect_loss = loss_fct(start_effect_logits, start_effect_positions)
81 | end_effect_loss = loss_fct(end_effect_logits, end_effect_positions)
82 | total_loss = (start_cause_loss + end_cause_loss + start_effect_loss + end_effect_loss) / 4
83 | outputs = (total_loss,) + outputs
84 |
85 | return outputs
86 |
--------------------------------------------------------------------------------
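All the `*ForCauseEffect` models in this repository share the same span-extraction head: one `nn.Linear(hidden_size, 2)` per span type whose output is split into per-token start and end logits. A self-contained shape sketch with toy dimensions (not repository code):

```python
import torch
from torch import nn

batch_size, seq_len, hidden_size = 2, 5, 8  # toy sizes
span_head = nn.Linear(hidden_size, 2)       # one head shown; the models use two (cause and effect)

hidden_states = torch.randn(batch_size, seq_len, hidden_size)
logits = span_head(hidden_states)                   # (bs, seq_len, 2)
start_logits, end_logits = logits.split(1, dim=-1)  # two tensors of shape (bs, seq_len, 1)
start_logits = start_logits.squeeze(-1)             # (bs, seq_len)
end_logits = end_logits.squeeze(-1)                 # (bs, seq_len)
print(start_logits.shape, end_logits.shape)         # torch.Size([2, 5]) twice
```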
/src/models/albert.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | # Copyright 2020 Guillaume Becquin.
4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 |
19 | from torch import nn
20 | from transformers import AlbertModel, AlbertPreTrainedModel
21 |
22 |
23 | class AlbertForCauseEffect(AlbertPreTrainedModel):
24 | def __init__(self, config):
25 | super(AlbertForCauseEffect, self).__init__(config)
26 |
27 | self.albert = AlbertModel(config)
28 | self.cause_outputs = nn.Linear(config.hidden_size, config.num_labels)
29 | self.effect_outputs = nn.Linear(config.hidden_size, config.num_labels)
30 | assert config.num_labels == 2
31 | self.init_weights()
32 |
33 | def forward(
34 | self,
35 | input_ids=None,
36 | attention_mask=None,
37 | token_type_ids=None,
38 | position_ids=None,
39 | head_mask=None,
40 | inputs_embeds=None,
41 | start_cause_positions=None,
42 | end_cause_positions=None,
43 | start_effect_positions=None,
44 | end_effect_positions=None,
45 | ):
46 | albert_output = self.albert(
47 | input_ids=input_ids,
48 | attention_mask=attention_mask,
49 | token_type_ids=token_type_ids,
50 | position_ids=position_ids,
51 | head_mask=head_mask,
52 | inputs_embeds=inputs_embeds
53 | )
54 | hidden_states = albert_output[0] # (bs, max_query_len, dim)
55 | cause_logits = self.cause_outputs(hidden_states) # (bs, max_query_len, 2)
56 | effect_logits = self.effect_outputs(hidden_states) # (bs, max_query_len, 2)
57 | start_cause_logits, end_cause_logits = cause_logits.split(1, dim=-1)
58 | start_effect_logits, end_effect_logits = effect_logits.split(1, dim=-1)
59 |
60 | start_cause_logits = start_cause_logits.squeeze(-1) # (bs, max_query_len)
61 | end_cause_logits = end_cause_logits.squeeze(-1) # (bs, max_query_len)
62 | start_effect_logits = start_effect_logits.squeeze(-1) # (bs, max_query_len)
63 | end_effect_logits = end_effect_logits.squeeze(-1) # (bs, max_query_len)
64 |
65 | outputs = (start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits,) \
66 | + albert_output[2:]
67 | if start_cause_positions is not None \
68 | and end_cause_positions is not None \
69 | and start_effect_positions is not None \
70 | and end_effect_positions is not None:
71 | # sometimes the start/end positions are outside our model inputs, we ignore these terms
72 | ignored_index = start_cause_logits.size(1)
73 | start_cause_positions.clamp_(0, ignored_index)
74 | end_cause_positions.clamp_(0, ignored_index)
75 | start_effect_positions.clamp_(0, ignored_index)
76 | end_effect_positions.clamp_(0, ignored_index)
77 |
78 | loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
79 | start_cause_loss = loss_fct(start_cause_logits, start_cause_positions)
80 | end_cause_loss = loss_fct(end_cause_logits, end_cause_positions)
81 | start_effect_loss = loss_fct(start_effect_logits, start_effect_positions)
82 | end_effect_loss = loss_fct(end_effect_logits, end_effect_positions)
83 | total_loss = (start_cause_loss + end_cause_loss + start_effect_loss + end_effect_loss) / 4
84 | outputs = (total_loss,) + outputs
85 |
86 | return outputs
87 |
--------------------------------------------------------------------------------
/src/models/distilbert.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | # Copyright 2020 Guillaume Becquin.
4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | from torch import nn
19 | from transformers import DistilBertPreTrainedModel, DistilBertModel
20 |
21 |
22 | class DistilBertForCauseEffect(DistilBertPreTrainedModel):
23 | def __init__(self, config):
24 | super().__init__(config)
25 |
26 | self.distilbert = DistilBertModel(config)
27 | self.cause_outputs = nn.Linear(config.dim, config.num_labels)
28 | self.effect_outputs = nn.Linear(config.dim, config.num_labels)
29 | assert config.num_labels == 2
30 | self.dropout = nn.Dropout(config.qa_dropout)
31 | self.init_weights()
32 |
33 | def forward(
34 | self,
35 | input_ids=None,
36 | attention_mask=None,
37 | head_mask=None,
38 | inputs_embeds=None,
39 | start_cause_positions=None,
40 | end_cause_positions=None,
41 | start_effect_positions=None,
42 | end_effect_positions=None,
43 | ):
44 | distilbert_output = self.distilbert(
45 | input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds
46 | )
47 | hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
48 |
49 | hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim)
50 | cause_logits = self.cause_outputs(hidden_states) # (bs, max_query_len, 2)
51 | effect_logits = self.effect_outputs(hidden_states) # (bs, max_query_len, 2)
52 | start_cause_logits, end_cause_logits = cause_logits.split(1, dim=-1)
53 | start_effect_logits, end_effect_logits = effect_logits.split(1, dim=-1)
54 |
55 | start_cause_logits = start_cause_logits.squeeze(-1) # (bs, max_query_len)
56 | end_cause_logits = end_cause_logits.squeeze(-1) # (bs, max_query_len)
57 | start_effect_logits = start_effect_logits.squeeze(-1) # (bs, max_query_len)
58 | end_effect_logits = end_effect_logits.squeeze(-1) # (bs, max_query_len)
59 |
60 | outputs = (start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits,) \
61 | + distilbert_output[1:]
62 | if start_cause_positions is not None \
63 | and end_cause_positions is not None \
64 | and start_effect_positions is not None \
65 | and end_effect_positions is not None:
66 | # sometimes the start/end positions are outside our model inputs, we ignore these terms
67 | ignored_index = start_cause_logits.size(1)
68 | start_cause_positions.clamp_(0, ignored_index)
69 | end_cause_positions.clamp_(0, ignored_index)
70 | start_effect_positions.clamp_(0, ignored_index)
71 | end_effect_positions.clamp_(0, ignored_index)
72 |
73 | loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
74 | start_cause_loss = loss_fct(start_cause_logits, start_cause_positions)
75 | end_cause_loss = loss_fct(end_cause_logits, end_cause_positions)
76 | start_effect_loss = loss_fct(start_effect_logits, start_effect_positions)
77 | end_effect_loss = loss_fct(end_effect_logits, end_effect_positions)
78 | total_loss = (start_cause_loss + end_cause_loss + start_effect_loss + end_effect_loss) / 4
79 | outputs = (total_loss,) + outputs
80 |
81 |         return outputs  # (loss), start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits, (hidden_states), (attentions)
82 |
--------------------------------------------------------------------------------
/src/models/bert.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | # Copyright 2020 Guillaume Becquin.
4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | from torch import nn
19 | from transformers import BertModel, BertPreTrainedModel
20 |
21 |
22 | class BertForCauseEffect(BertPreTrainedModel):
23 | def __init__(self, config):
24 | super(BertForCauseEffect, self).__init__(config)
25 |
26 | self.bert = BertModel(config)
27 | self.cause_outputs = nn.Linear(config.hidden_size, config.num_labels)
28 | self.effect_outputs = nn.Linear(config.hidden_size, config.num_labels)
29 | assert config.num_labels == 2
30 | self.init_weights()
31 |
32 | def forward(
33 | self,
34 | input_ids=None,
35 | attention_mask=None,
36 | token_type_ids=None,
37 | position_ids=None,
38 | head_mask=None,
39 | inputs_embeds=None,
40 | start_cause_positions=None,
41 | end_cause_positions=None,
42 | start_effect_positions=None,
43 | end_effect_positions=None,
44 | ):
45 | bert_output = self.bert(
46 | input_ids=input_ids,
47 | attention_mask=attention_mask,
48 | token_type_ids=token_type_ids,
49 | position_ids=position_ids,
50 | head_mask=head_mask,
51 | inputs_embeds=inputs_embeds
52 | )
53 | hidden_states = bert_output[0] # (bs, max_query_len, dim)
54 | cause_logits = self.cause_outputs(hidden_states) # (bs, max_query_len, 2)
55 | effect_logits = self.effect_outputs(hidden_states) # (bs, max_query_len, 2)
56 | start_cause_logits, end_cause_logits = cause_logits.split(1, dim=-1)
57 | start_effect_logits, end_effect_logits = effect_logits.split(1, dim=-1)
58 |
59 | start_cause_logits = start_cause_logits.squeeze(-1) # (bs, max_query_len)
60 | end_cause_logits = end_cause_logits.squeeze(-1) # (bs, max_query_len)
61 | start_effect_logits = start_effect_logits.squeeze(-1) # (bs, max_query_len)
62 | end_effect_logits = end_effect_logits.squeeze(-1) # (bs, max_query_len)
63 |
64 | outputs = (start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits,) \
65 | + bert_output[2:]
66 | if start_cause_positions is not None \
67 | and end_cause_positions is not None \
68 | and start_effect_positions is not None \
69 | and end_effect_positions is not None:
70 | # sometimes the start/end positions are outside our model inputs, we ignore these terms
71 | ignored_index = start_cause_logits.size(1)
72 | start_cause_positions.clamp_(0, ignored_index)
73 | end_cause_positions.clamp_(0, ignored_index)
74 | start_effect_positions.clamp_(0, ignored_index)
75 | end_effect_positions.clamp_(0, ignored_index)
76 |
77 | loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
78 | start_cause_loss = loss_fct(start_cause_logits, start_cause_positions)
79 | end_cause_loss = loss_fct(end_cause_logits, end_cause_positions)
80 | start_effect_loss = loss_fct(start_effect_logits, start_effect_positions)
81 | end_effect_loss = loss_fct(end_effect_logits, end_effect_positions)
82 | total_loss = (start_cause_loss + end_cause_loss + start_effect_loss + end_effect_loss) / 4
83 | outputs = (total_loss,) + outputs
84 |
85 |         return outputs  # (loss), start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits, (hidden_states), (attentions)
86 |
--------------------------------------------------------------------------------
/src/models/xlnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | # Copyright 2020 Guillaume Becquin.
4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | from torch import nn
19 | from torch.nn import CrossEntropyLoss
20 | from transformers import XLNetPreTrainedModel, XLNetModel
21 |
22 |
23 | class XLNetForCauseEffect(XLNetPreTrainedModel):
24 | def __init__(self, config):
25 | super().__init__(config)
26 | self.num_labels = config.num_labels
27 |
28 | self.transformer = XLNetModel(config)
29 | self.cause_outputs = nn.Linear(config.hidden_size, config.num_labels)
30 | self.effect_outputs = nn.Linear(config.hidden_size, config.num_labels)
31 | self.init_weights()
32 |
33 | def forward(
34 | self,
35 | input_ids=None,
36 | attention_mask=None,
37 | mems=None,
38 | perm_mask=None,
39 | target_mapping=None,
40 | token_type_ids=None,
41 | input_mask=None,
42 | head_mask=None,
43 | inputs_embeds=None,
44 | use_cache=True,
45 | start_cause_positions=None,
46 | end_cause_positions=None,
47 | start_effect_positions=None,
48 | end_effect_positions=None,
49 | ):
50 |
51 | outputs = self.transformer(
52 | input_ids,
53 | attention_mask=attention_mask,
54 | mems=mems,
55 | perm_mask=perm_mask,
56 | target_mapping=target_mapping,
57 | token_type_ids=token_type_ids,
58 | input_mask=input_mask,
59 | head_mask=head_mask,
60 | inputs_embeds=inputs_embeds,
61 | use_cache=use_cache,
62 | )
63 |
64 | sequence_output = outputs[0]
65 |
66 | cause_logits = self.cause_outputs(sequence_output)
67 | start_cause_logits, end_cause_logits = cause_logits.split(1, dim=-1)
68 | start_cause_logits = start_cause_logits.squeeze(-1)
69 | end_cause_logits = end_cause_logits.squeeze(-1)
70 |
71 | effect_logits = self.effect_outputs(sequence_output)
72 | start_effect_logits, end_effect_logits = effect_logits.split(1, dim=-1)
73 | start_effect_logits = start_effect_logits.squeeze(-1)
74 | end_effect_logits = end_effect_logits.squeeze(-1)
75 |
76 | outputs = (start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits,) + outputs[2:]
77 | if start_cause_positions is not None \
78 | and end_cause_positions is not None \
79 | and start_effect_positions is not None \
80 | and end_effect_positions is not None:
81 | # sometimes the start/end positions are outside our model inputs, we ignore these terms
82 | ignored_index = start_cause_logits.size(1)
83 | start_cause_positions.clamp_(0, ignored_index)
84 | end_cause_positions.clamp_(0, ignored_index)
85 | start_effect_positions.clamp_(0, ignored_index)
86 | end_effect_positions.clamp_(0, ignored_index)
87 |
88 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
89 | start_cause_loss = loss_fct(start_cause_logits, start_cause_positions)
90 | end_cause_loss = loss_fct(end_cause_logits, end_cause_positions)
91 | start_effect_loss = loss_fct(start_effect_logits, start_effect_positions)
92 | end_effect_loss = loss_fct(end_effect_logits, end_effect_positions)
93 | total_loss = (start_cause_loss + end_cause_loss + start_effect_loss + end_effect_loss) / 4
94 | outputs = (total_loss,) + outputs
95 |
96 | return outputs
97 |
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020, Guillaume Becquin
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from enum import Enum
16 | from typing import Callable
17 |
18 | import torch
19 | from transformers import AdamW, get_cosine_schedule_with_warmup, AlbertTokenizer, XLNetTokenizer, RobertaTokenizer, \
20 | BertTokenizer, DistilBertTokenizer
21 |
22 | from src.models.albert import AlbertForCauseEffect
23 | from src.models.bert import BertForCauseEffect
24 | from src.models.distilbert import DistilBertForCauseEffect
25 | from src.models.roberta import RoBERTaForCauseEffect
26 | from src.models.xlnet import XLNetForCauseEffect
27 |
28 |
29 | class ModelConfigurations(Enum):
30 | BertBase = ('bert', 'bert-base-cased', False)
31 | BertLarge = ('bert', 'bert-large-cased', False)
32 | BertSquad = ('bert', 'deepset/bert-base-cased-squad2', False)
33 | BertSquad2 = ('bert', 'deepset/bert-large-uncased-whole-word-masking-squad2', True)
34 | DistilBert = ('distilbert', 'distilbert-base-uncased', True)
35 | DistilBertSquad = ('distilbert', 'distilbert-base-uncased-distilled-squad', True)
36 | RoBERTaSquad = ('roberta', 'deepset/roberta-base-squad2', False)
37 | RoBERTaSquadLarge = ('roberta', 'ahotrod/roberta_large_squad2', False)
38 | RoBERTa = ('roberta', 'roberta-base', False)
39 | RoBERTaLarge = ('roberta', 'roberta-large', False)
40 | XLNetBase = ('xlnet', 'xlnet-base-cased', False)
41 | AlbertSquad = ('albert', 'twmkn9/albert-base-v2-squad2', True)
42 |
43 |
44 | model_tokenizer_mapping = {
45 | 'distilbert': (DistilBertForCauseEffect, DistilBertTokenizer),
46 | 'bert': (BertForCauseEffect, BertTokenizer),
47 | 'roberta': (RoBERTaForCauseEffect, RobertaTokenizer),
48 | 'xlnet': (XLNetForCauseEffect, XLNetTokenizer),
49 | 'albert': (AlbertForCauseEffect, AlbertTokenizer),
50 | }
51 |
52 |
53 | class RunConfig:
54 | def __init__(self,
55 | do_train: bool = False,
56 | do_eval: bool = True,
57 | do_test: bool = False,
58 | max_seq_length: int = 384,
59 | doc_stride: int = 128,
60 | train_batch_size: int = 4,
61 | gradient_accumulation_steps: int = 3,
62 | warmup_steps: int = 50,
63 | learning_rate: float = 3e-5,
64 | differential_lr_ratio: float = 1.0,
65 | max_grad_norm: float = 1.0,
66 | adam_epsilon: float = 1e-8,
67 | num_train_epochs: int = 5,
68 | save_model: bool = True,
69 | weight_decay: float = 0.0,
70 | optimizer_class: torch.optim.Optimizer = AdamW,
71 | scheduler_function: Callable = get_cosine_schedule_with_warmup,
72 | evaluate_during_training: bool = True,
73 | eval_batch_size: int = 8,
74 | n_best_size: int = 5,
75 | max_answer_length: int = 300,
76 | sentence_boundary_heuristic: bool = True,
77 | full_sentence_heuristic: bool = True,
78 | shared_sentence_heuristic: bool = False,
79 | top_n_sentences: bool = True):
80 | self.do_train = do_train
81 | self.do_eval = do_eval
82 | self.do_test = do_test
83 | self.max_seq_length = max_seq_length
84 | self.doc_stride = doc_stride
85 | self.train_batch_size = train_batch_size
86 | self.gradient_accumulation_steps = gradient_accumulation_steps
87 | self.warmup_steps = warmup_steps
88 | self.learning_rate = learning_rate
89 | self.differential_lr_ratio = differential_lr_ratio
90 | self.max_grad_norm = max_grad_norm
91 | self.adam_epsilon = adam_epsilon
92 | self.num_train_epochs = num_train_epochs
93 | self.save_model = save_model
94 | self.weight_decay = weight_decay
95 | self.optimizer_class = optimizer_class
96 | self.scheduler_function = scheduler_function
97 | self.evaluate_during_training = evaluate_during_training
98 | self.eval_batch_size = eval_batch_size
99 | self.n_best_size = n_best_size
100 | self.max_answer_length = max_answer_length
101 | self.sentence_boundary_heuristic = sentence_boundary_heuristic
102 | self.full_sentence_heuristic = full_sentence_heuristic
103 | self.shared_sentence_heuristic = shared_sentence_heuristic
104 | self.top_n_sentences = top_n_sentences
105 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | # Copyright 2020 Guillaume Becquin. All rights reserved.
4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 |
19 | import json
20 | import logging
21 | import os
22 | import sys
23 | from pathlib import Path
24 | import torch
25 | from src.config import RunConfig, ModelConfigurations, model_tokenizer_mapping
26 | from src.evaluation import evaluate, predict
27 | from src.logging import initialize_log_dict
28 | from src.preprocessing import load_examples
29 | from src.training import train
30 | import argparse
31 |
32 | logging.basicConfig(level=logging.INFO)
33 | logger = logging.getLogger(__name__)
34 |
35 | if __name__ == '__main__':
36 |
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument(
39 | "--train",
40 | default=False,
41 | required=False,
42 | action="store_true",
43 | help="Flag to specify if the model should be trained",
44 | )
45 |
46 | parser.add_argument(
47 | "--eval",
48 | default=False,
49 | required=False,
50 | action="store_true",
51 | help="Flag to specify if the model should be evaluated",
52 | )
53 |
54 | parser.add_argument(
55 | "--test",
56 | default=False,
57 | required=False,
58 | action="store_true",
59 |         help="Flag to specify if the model should generate predictions on the test file",
60 | )
61 | args = parser.parse_args()
62 | assert args.train or args.eval or args.test, \
63 | "At least one task needs to be selected by passing --train, --eval or --test"
64 |
65 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66 |
67 | model_config = ModelConfigurations.RoBERTaSquadLarge
68 | run_config = RunConfig()
69 | run_config.do_train = args.train
70 | run_config.do_eval = args.eval
71 | run_config.do_test = args.test
72 |
73 | RUN_NAME = 'model_run'
74 |
75 | (MODEL_TYPE, MODEL_NAME_OR_PATH, DO_LOWER_CASE) = model_config.value
76 |
77 |     fincausal_data_path = Path(os.environ.get('FINCAUSAL_DATA_PATH',
78 |                                                os.path.dirname(os.path.realpath(sys.argv[0])) + '/data'))
79 |     fincausal_output_path = Path(os.environ.get('FINCAUSAL_OUTPUT_PATH',
80 |                                                  os.path.dirname(os.path.realpath(sys.argv[0])) + '/output'))
81 |
82 | TRAIN_FILE = fincausal_data_path / "fnp2020-train.csv"
83 | EVAL_FILE = fincausal_data_path / "fnp2020-eval.csv"
84 | TEST_FILE = fincausal_data_path / "task2.csv"
85 |
86 | if RUN_NAME:
87 | OUTPUT_DIR = fincausal_output_path / (MODEL_NAME_OR_PATH + '_' + RUN_NAME)
88 | else:
89 | OUTPUT_DIR = fincausal_output_path / MODEL_NAME_OR_PATH
90 |
91 | model_class, tokenizer_class = model_tokenizer_mapping[MODEL_TYPE]
92 | log_file = initialize_log_dict(model_config=model_config,
93 | run_config=run_config,
94 | model_tokenizer_mapping=model_tokenizer_mapping)
95 |
96 | # Training
97 | if run_config.do_train:
98 |
99 | tokenizer = tokenizer_class.from_pretrained(MODEL_NAME_OR_PATH,
100 | do_lower_case=DO_LOWER_CASE,
101 | cache_dir=OUTPUT_DIR)
102 | model = model_class.from_pretrained(MODEL_NAME_OR_PATH).to(device)
103 |
104 | train_dataset = load_examples(file_path=TRAIN_FILE,
105 | tokenizer=tokenizer,
106 | output_examples=False,
107 | run_config=run_config)
108 |
109 | train(train_dataset=train_dataset,
110 | model=model,
111 | tokenizer=tokenizer,
112 | model_type=MODEL_TYPE,
113 | output_dir=OUTPUT_DIR,
114 | predict_file=EVAL_FILE,
115 | device=device,
116 | log_file=log_file,
117 | run_config=run_config
118 | )
119 | if not OUTPUT_DIR.is_dir():
120 | OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
121 | if run_config.save_model:
122 | model_to_save = model.module if hasattr(model, "module") else model
123 | model_to_save.save_pretrained(OUTPUT_DIR)
124 | tokenizer.save_pretrained(OUTPUT_DIR)
125 | logger.info("Saving final model to %s", OUTPUT_DIR)
126 | logger.info("Saving log file to %s", OUTPUT_DIR)
127 | with open(os.path.join(OUTPUT_DIR, "logs.json"), 'w') as f:
128 | json.dump(log_file, f, indent=4)
129 |
130 | if run_config.do_eval:
131 | tokenizer = tokenizer_class.from_pretrained(str(OUTPUT_DIR), do_lower_case=DO_LOWER_CASE)
132 | model = model_class.from_pretrained(str(OUTPUT_DIR)).to(device)
133 |
134 | result = evaluate(model=model,
135 | tokenizer=tokenizer,
136 | device=device,
137 | file_path=EVAL_FILE,
138 | model_type=MODEL_TYPE,
139 | output_dir=OUTPUT_DIR,
140 | run_config=run_config
141 | )
142 |
143 | print("done")
144 |
145 | if run_config.do_test:
146 | tokenizer = tokenizer_class.from_pretrained(str(OUTPUT_DIR), do_lower_case=DO_LOWER_CASE)
147 | model = model_class.from_pretrained(str(OUTPUT_DIR)).to(device)
148 |
149 | result = predict(model=model,
150 | tokenizer=tokenizer,
151 | device=device,
152 | file_path=TEST_FILE,
153 | model_type=MODEL_TYPE,
154 | output_dir=OUTPUT_DIR,
155 | run_config=run_config
156 | )
157 |
158 | print("done")
159 |
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | # Copyright 2020 Guillaume Becquin.
4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import unicodedata
19 | from typing import Optional, List
20 |
21 | import pandas as pd
22 |
23 |
24 | class FinCausalExample:
25 |
26 | def __init__(
27 | self,
28 | example_id: str,
29 | context_text: str,
30 | cause_text: str,
31 | effect_text: str,
32 | offset_sentence_2: int,
33 | offset_sentence_3: int,
34 | cause_start_position_character: Optional[int],
35 | cause_end_position_character: Optional[int],
36 | effect_start_position_character: Optional[int],
37 | effect_end_position_character: Optional[int]
38 | ):
39 |
40 | self.example_id = example_id
41 | self.context_text = context_text
42 | self.cause_text = cause_text
43 | self.effect_text = effect_text
44 |
45 | self.start_cause_position, self.end_cause_position = 0, 0
46 | self.start_effect_position, self.end_effect_position = 0, 0
47 |
48 | doc_tokens: List[str] = []
49 | char_to_word_offset: List[int] = []
50 | word_to_char_mapping: List[int] = []
51 | prev_is_whitespace = True
52 |
53 | # Split on whitespace so that different tokens may be attributed to their original position.
54 | for char_index, c in enumerate(self.context_text):
55 | if _is_whitespace(c):
56 | prev_is_whitespace = True
57 | else:
58 | if prev_is_whitespace or _is_punctuation(c):
59 | doc_tokens.append(c)
60 | word_to_char_mapping.append(char_index)
61 | else:
62 | doc_tokens[-1] += c
63 | if _is_punctuation(c):
64 | prev_is_whitespace = True
65 | else:
66 | prev_is_whitespace = False
67 | char_to_word_offset.append(len(doc_tokens) - 1)
68 |
69 | self.doc_tokens = doc_tokens
70 | self.char_to_word_offset = char_to_word_offset
71 |
72 |         # Start and end positions only have a value when labels are available (training / evaluation).
73 | if cause_start_position_character is not None:
74 | assert (cause_start_position_character + len(cause_text) == cause_end_position_character)
75 | self.start_cause_position = char_to_word_offset[cause_start_position_character]
76 | self.end_cause_position = char_to_word_offset[
77 | min(cause_start_position_character + len(cause_text) - 1, len(char_to_word_offset) - 1)
78 | ]
79 | if effect_start_position_character is not None:
80 | self.start_effect_position = char_to_word_offset[effect_start_position_character]
81 | assert (effect_start_position_character + len(effect_text) == effect_end_position_character)
82 | self.end_effect_position = char_to_word_offset[
83 | min(effect_start_position_character + len(effect_text) - 1, len(char_to_word_offset) - 1)
84 | ]
85 |
86 | if pd.notna(offset_sentence_2):
87 | self.offset_sentence_2 = char_to_word_offset[int(offset_sentence_2)]
88 | else:
89 | self.offset_sentence_2 = 0
90 | if pd.notna(offset_sentence_3):
91 | self.offset_sentence_3 = char_to_word_offset[int(offset_sentence_3)]
92 | else:
93 | self.offset_sentence_3 = 0
94 |
95 | self.word_to_char_mapping = word_to_char_mapping
96 |
97 |
98 | class FinCausalFeatures:
99 |
100 | def __init__(
101 | self,
102 | input_ids,
103 | attention_mask,
104 | token_type_ids,
105 | cls_index,
106 | p_mask,
107 | example_orig_index,
108 | example_index,
109 | unique_id,
110 | paragraph_len,
111 | token_is_max_context,
112 | tokens,
113 | token_to_orig_map,
114 | cause_start_position,
115 | cause_end_position,
116 | effect_start_position,
117 | effect_end_position,
118 | sentence_2_offset,
119 | sentence_3_offset,
120 | is_impossible,
121 | ):
122 | self.input_ids = input_ids
123 | self.attention_mask = attention_mask
124 | self.token_type_ids = token_type_ids
125 | self.cls_index = cls_index
126 | self.p_mask = p_mask
127 |
128 | self.example_orig_index = example_orig_index
129 | self.example_index = example_index
130 | self.unique_id = unique_id
131 | self.paragraph_len = paragraph_len
132 | self.token_is_max_context = token_is_max_context
133 | self.tokens = tokens
134 | self.token_to_orig_map = token_to_orig_map
135 |
136 | self.cause_start_position = cause_start_position
137 | self.cause_end_position = cause_end_position
138 | self.effect_start_position = effect_start_position
139 | self.effect_end_position = effect_end_position
140 |
141 | self.sentence_2_offset = sentence_2_offset
142 | self.sentence_3_offset = sentence_3_offset
143 |
144 | self.is_impossible = is_impossible
145 |
146 |
147 | class FinCausalResult:
148 | def __init__(self, unique_id, start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits,
149 | start_top_index=None, end_top_index=None, cls_logits=None):
150 | self.start_cause_logits = start_cause_logits
151 | self.end_cause_logits = end_cause_logits
152 | self.start_effect_logits = start_effect_logits
153 | self.end_effect_logits = end_effect_logits
154 | self.unique_id = unique_id
155 |
156 | if start_top_index:
157 | self.start_top_index = start_top_index
158 | self.end_top_index = end_top_index
159 | self.cls_logits = cls_logits
160 |
161 |
162 | def _is_whitespace(char: str) -> bool:
163 | if char == " " or char == "\t" or char == "\r" or char == "\n" or char == '\xa0' or ord(char) == 0x202F:
164 | return True
165 | return False
166 |
167 |
168 | def _is_punctuation(char: str) -> bool:
169 | cp = ord(char)
170 | if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126):
171 | return True
172 | cat = unicodedata.category(char)
173 | if cat.startswith("P"):
174 | return True
175 | return False
176 |
--------------------------------------------------------------------------------
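To make the whitespace/punctuation split and the character-to-word mapping built in `FinCausalExample` concrete, here is a small illustrative usage; the sentence, ID and offsets are made up, and it assumes the repository root is on `PYTHONPATH`:

```python
from src.data import FinCausalExample

example = FinCausalExample(
    example_id="0001.00001",                 # illustrative ID
    context_text="Prices fell. Sales dropped.",
    cause_text="Prices fell.",
    effect_text="Sales dropped.",
    offset_sentence_2=13,                    # character offset of the second sentence
    offset_sentence_3=float("nan"),          # no third sentence in this text
    cause_start_position_character=0,
    cause_end_position_character=12,         # start + len(cause_text)
    effect_start_position_character=13,
    effect_end_position_character=27,        # start + len(effect_text)
)

print(example.doc_tokens)
# ['Prices', 'fell', '.', 'Sales', 'dropped', '.']  (punctuation split into its own token)
print(example.start_cause_position, example.end_cause_position)    # 0 2 (word-level span)
print(example.start_effect_position, example.end_effect_position)  # 3 5
```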
/src/training.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | # Copyright 2020 Guillaume Becquin.
4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import logging
19 | from pathlib import Path
20 | from typing import Dict, Union
21 |
22 | import torch
23 | from torch.nn import Module
24 | from torch.utils.data import RandomSampler, DataLoader, TensorDataset
25 | from tqdm import trange, tqdm
26 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
27 |
28 | from .config import RunConfig
29 | from .evaluation import evaluate
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 |
34 | def train(train_dataset: TensorDataset,
35 | model: Module,
36 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
37 | model_type: str,
38 | output_dir: Path,
39 | predict_file: Path,
40 | log_file: Dict,
41 | device: torch.device,
42 | run_config: RunConfig):
43 | train_sampler = RandomSampler(train_dataset)
44 | train_dataloader = DataLoader(train_dataset,
45 | sampler=train_sampler,
46 | batch_size=run_config.train_batch_size)
47 |
48 | t_total = len(train_dataloader) // run_config.gradient_accumulation_steps * run_config.num_train_epochs
49 |
50 | # Define Optimizer and learning rates / decay
51 | no_decay = ["bias", "LayerNorm.weight"]
52 | no_scaled_lr = ["cause_outputs", "effect_outputs"]
53 | if run_config.differential_lr_ratio == 0:
54 | differential_lr_ratio = 1.0
55 | else:
56 | differential_lr_ratio = run_config.differential_lr_ratio
57 | assert differential_lr_ratio <= 1, "ratio for language model layers should be <= 1"
58 | optimizer_grouped_parameters = [
59 | {
60 | "params": [p for n, p in model.named_parameters() if (not any(nd in n for nd in no_decay)
61 | and not any(nlr in n for nlr in no_scaled_lr))],
62 | 'lr': run_config.learning_rate * differential_lr_ratio,
63 | "weight_decay": run_config.weight_decay,
64 | },
65 | {
66 | "params": [p for n, p in model.named_parameters() if (not any(nd in n for nd in no_decay)
67 | and any(nlr in n for nlr in no_scaled_lr))],
68 | 'lr': run_config.learning_rate,
69 | "weight_decay": run_config.weight_decay,
70 | },
71 | {
72 | "params": [p for n, p in model.named_parameters() if (any(nd in n for nd in no_decay)
73 | and not any(nlr in n for nlr in no_scaled_lr))],
74 | 'lr': run_config.learning_rate * differential_lr_ratio,
75 | "weight_decay": 0.0
76 | },
77 | {
78 | "params": [p for n, p in model.named_parameters() if (any(nd in n for nd in no_decay)
79 | and any(nlr in n for nlr in no_scaled_lr))],
80 | 'lr': run_config.learning_rate,
81 | "weight_decay": 0.0
82 | },
83 | ]
84 | optimizer = run_config.optimizer_class(optimizer_grouped_parameters,
85 | lr=run_config.learning_rate,
86 | eps=run_config.adam_epsilon)
87 |
88 | # Define Scheduler
89 | try:
90 | scheduler = run_config.scheduler_function(optimizer,
91 | num_warmup_steps=run_config.warmup_steps,
92 | num_training_steps=t_total)
93 | except ValueError:
94 | scheduler = run_config.scheduler_function(optimizer,
95 | num_warmup_steps=run_config.warmup_steps)
96 |
97 | # Start training
98 | logger.info("***** Running training *****")
99 | logger.info(" Num examples = %d", len(train_dataset))
100 | logger.info(" Num Epochs = %d", run_config.num_train_epochs)
101 | logger.info(
102 | " Total train batch size (w. parallel, distributed & accumulation) = %d",
103 | run_config.train_batch_size * run_config.gradient_accumulation_steps
104 | )
105 | logger.info(" Gradient Accumulation steps = %d", run_config.gradient_accumulation_steps)
106 | logger.info(" Total optimization steps = %d", t_total)
107 |
108 | global_step = 1
109 | epochs_trained = 0
110 |
111 | tr_loss, logging_loss = 0.0, 0.0
112 | model.zero_grad()
113 | train_iterator = trange(epochs_trained, int(run_config.num_train_epochs), desc="Epoch")
114 |
115 | for _ in train_iterator:
116 | epoch_iterator = tqdm(train_dataloader, desc=f"Iteration Loss: {tr_loss / global_step}", position=0, leave=True)
117 | for step, batch in enumerate(epoch_iterator):
118 | epoch_iterator.set_description(f"Iteration Loss: {tr_loss / global_step}")
119 |
120 | model.train()
121 | batch = tuple(t.to(device) for t in batch)
122 |
123 | inputs = {
124 | "input_ids": batch[0],
125 | "attention_mask": batch[1],
126 | "token_type_ids": batch[2],
127 | "start_cause_positions": batch[3],
128 | "end_cause_positions": batch[4],
129 | "start_effect_positions": batch[5],
130 | "end_effect_positions": batch[6],
131 | }
132 |
133 | if model_type in ["xlm", "roberta", "distilbert", "camembert"]:
134 | del inputs["token_type_ids"]
135 |
136 | outputs = model(**inputs)
137 | loss = outputs[0]
138 |
139 | if run_config.gradient_accumulation_steps > 1:
140 | loss = loss / run_config.gradient_accumulation_steps
141 |
142 | loss.backward()
143 |
144 | tr_loss += loss.item()
145 | if (step + 1) % run_config.gradient_accumulation_steps == 0:
146 | torch.nn.utils.clip_grad_norm_(model.parameters(), run_config.max_grad_norm)
147 | optimizer.step()
148 | scheduler.step()
149 | model.zero_grad()
150 | global_step += 1
151 |
152 | # Evaluate model and log metrics
153 | if run_config.evaluate_during_training:
154 | metrics = evaluate(model=model,
155 | tokenizer=tokenizer,
156 | device=device,
157 | file_path=predict_file,
158 | model_type=model_type,
159 | output_dir=output_dir,
160 | run_config=run_config)
161 | log_file[f'step_{global_step}'] = metrics
162 |
163 | _output_dir = output_dir / "checkpoint-{}".format(global_step)
164 | if not _output_dir.is_dir():
165 | _output_dir.mkdir(parents=True, exist_ok=True)
166 |
167 | model.save_pretrained(_output_dir)
168 | tokenizer.save_pretrained(_output_dir)
169 |             logger.info("Saving model checkpoint to %s", _output_dir)
170 |
--------------------------------------------------------------------------------
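The optimizer setup in `train` above partitions the parameters four ways: {weight decay, no decay} × {backbone at a scaled-down learning rate, freshly initialised cause/effect heads at the full rate}. A compact per-parameter sketch of the same idea with a toy model (not the repository's code):

```python
from torch import nn
from transformers import AdamW

# Toy stand-in: "encoder" plays the pre-trained backbone,
# "cause_outputs" the newly initialised span head (kept at the full LR).
model = nn.ModuleDict({
    "encoder": nn.Linear(4, 4),
    "cause_outputs": nn.Linear(4, 2),
})
base_lr, ratio, weight_decay = 3e-5, 0.5, 0.01
no_decay = ("bias", "LayerNorm.weight")
new_heads = ("cause_outputs", "effect_outputs")

groups = []
for name, param in model.named_parameters():
    is_head = any(h in name for h in new_heads)
    groups.append({
        "params": [param],
        # backbone parameters train at a scaled-down learning rate
        "lr": base_lr if is_head else base_lr * ratio,
        # biases / LayerNorm weights are excluded from weight decay
        "weight_decay": 0.0 if any(nd in name for nd in no_decay) else weight_decay,
    })

optimizer = AdamW(groups, lr=base_lr)
print([(g["lr"], g["weight_decay"]) for g in optimizer.param_groups])
```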
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/src/preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | # Copyright 2020, Guillaume Becquin
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import logging
18 | import multiprocessing
19 | from collections import UserDict
20 | from functools import partial
21 | from pathlib import Path
22 | from typing import Union, List, Tuple
23 |
24 | import spacy
25 | from pysbd.utils import PySBDFactory
26 |
27 | import numpy as np
28 | import pandas as pd
29 | import torch
30 | from torch.utils.data import TensorDataset
31 | from tqdm import tqdm
32 | from transformers import PreTrainedTokenizer, BatchEncoding, PreTrainedTokenizerFast
33 | from transformers.tokenization_bert import whitespace_tokenize
34 |
35 | from .config import RunConfig
36 | from .data import FinCausalExample, FinCausalFeatures, _is_punctuation
37 |
38 | logger = logging.getLogger(__name__)
39 |
40 |
41 | def load_examples(file_path: Path,
42 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
43 | run_config: RunConfig,
44 | output_examples: bool = True,
45 | evaluate: bool = False) -> \
46 | Union[Tuple[TensorDataset, List[FinCausalExample], List[FinCausalFeatures]],
47 | TensorDataset]:
48 | processor = FinCausalProcessor()
49 | examples = processor.get_examples(file_path)
50 |
51 | features, dataset = fincausal_convert_examples_to_features(
52 | examples=examples,
53 | tokenizer=tokenizer,
54 | max_seq_length=run_config.max_seq_length,
55 | doc_stride=run_config.doc_stride,
56 | is_training=not evaluate,
57 | return_dataset="pt",
58 | threads=1,
59 | )
60 |
61 | if output_examples:
62 | return dataset, examples, features
63 | return dataset
64 |
65 |
66 | class FinCausalProcessor:
67 |
68 | def get_examples(self, file_path: Path) -> List[FinCausalExample]:
69 | input_data = pd.read_csv(file_path, index_col=0, delimiter=';', header=0, skipinitialspace=True)
70 | input_data.columns = [col_name.strip() for col_name in input_data.columns]
71 | return self._create_examples(input_data)
72 |
73 | @staticmethod
74 | def _create_examples(input_data: pd.DataFrame) -> List[FinCausalExample]:
75 |
76 | nlp = spacy.blank('en')
77 | nlp.add_pipe(PySBDFactory(nlp))
78 |
79 | examples = []
80 | for entry in tqdm(input_data.itertuples()):
81 | context_text = entry.Text
82 | example_id = entry.Index
83 | if len(entry) > 2:
84 | cause_text = entry.Cause
85 | effect_text = entry.Effect
86 | cause_start_position_character = entry.Cause_Start
87 | cause_end_position_character = entry.Cause_End
88 | effect_start_position_character = entry.Effect_Start
89 | effect_end_position_character = entry.Effect_End
90 | else:
91 | cause_text = ""
92 | effect_text = ""
93 | cause_start_position_character = 0
94 | cause_end_position_character = 0
95 | effect_start_position_character = 0
96 | effect_end_position_character = 0
97 |
98 | doc = nlp(entry.Text)
99 | sentences = list(doc.sents)
100 | offset_sentence_2 = np.nan
101 | offset_sentence_3 = np.nan
102 | if len(sentences) > 1:
103 | offset_sentence_2 = sentences[0].end_char
104 | if len(sentences) > 2:
105 | offset_sentence_3 = sentences[1].end_char
106 |
107 | example = FinCausalExample(
108 | example_id,
109 | context_text,
110 | cause_text,
111 | effect_text,
112 | offset_sentence_2,
113 | offset_sentence_3,
114 | cause_start_position_character,
115 | cause_end_position_character,
116 | effect_start_position_character,
117 | effect_end_position_character
118 | )
119 | examples.append(example)
120 | return examples
121 |
122 |
123 | def fincausal_convert_example_to_features_init(
124 | tokenizer_for_convert: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]):
125 | global tokenizer
126 | tokenizer = tokenizer_for_convert
127 |
128 |
129 | def _improve_answer_span(doc_tokens: List[str],
130 | input_start: int,
131 | input_end: int,
132 | tokenizer,
133 | orig_answer_text: str) -> Tuple[int, int]:
134 | """Returns tokenized answer spans that better match the annotated answer."""
135 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
136 |
137 | for new_start in range(input_start, input_end + 1):
138 | for new_end in range(input_end, new_start - 1, -1):
139 | text_span = " ".join(doc_tokens[new_start: (new_end + 1)])
140 | if text_span == tok_answer_text:
141 | return new_start, new_end
142 |
143 | return input_start, input_end
144 |
145 |
146 | def _check_is_max_context(doc_spans: Union[List[dict], List[UserDict]],
147 | cur_span_index: int,
148 | position: int) -> bool:
149 | """Check if this is the 'max context' doc span for the token."""
150 | best_score = None
151 | best_span_index = None
152 | for (span_index, doc_span) in enumerate(doc_spans):
153 | end = doc_span["start"] + doc_span["length"] - 1
154 | if position < doc_span["start"]:
155 | continue
156 | if position > end:
157 | continue
158 | num_left_context = position - doc_span["start"]
159 | num_right_context = end - position
160 | score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"]
161 | if best_score is None or score > best_score:
162 | best_score = score
163 | best_span_index = span_index
164 |
165 | return cur_span_index == best_span_index
166 |
167 |
168 | def fincausal_convert_example_to_features(example: FinCausalExample,
169 | max_seq_length: int,
170 | doc_stride: int,
171 | is_training: bool) -> List[FinCausalFeatures]:
172 | features = []
173 | if is_training:
174 | # Get start and end position
175 | start_cause_position = example.start_cause_position
176 | end_cause_position = example.end_cause_position
177 | start_effect_position = example.start_effect_position
178 | end_effect_position = example.end_effect_position
179 |
180 | # If the cause cannot be found in the text, then skip this example.
181 | actual_cause_text = " ".join(example.doc_tokens[start_cause_position: (end_cause_position + 1)])
182 | cleaned_cause_text = " ".join(whitespace_tokenize(_run_split_on_punc(example.cause_text)))
183 | if actual_cause_text.find(cleaned_cause_text) == -1:
184 | logger.warning("Could not find cause: '%s' vs. '%s'", actual_cause_text, cleaned_cause_text)
185 | return []
186 |
187 | # If the effect cannot be found in the text, then skip this example.
188 | actual_effect_text = " ".join(example.doc_tokens[start_effect_position: (end_effect_position + 1)])
189 | cleaned_effect_text = " ".join(whitespace_tokenize(_run_split_on_punc(example.effect_text)))
190 | if actual_effect_text.find(cleaned_effect_text) == -1:
191 | logger.warning("Could not find effect: '%s' vs. '%s'", actual_effect_text, cleaned_effect_text)
192 | return []
193 |
194 | tok_to_orig_index = []
195 | orig_to_tok_index = []
196 | all_doc_tokens = []
197 | for (i, token) in enumerate(example.doc_tokens):
198 | orig_to_tok_index.append(len(all_doc_tokens))
199 | sub_tokens = tokenizer.tokenize(token)
200 | for sub_token in sub_tokens:
201 | tok_to_orig_index.append(i)
202 | all_doc_tokens.append(sub_token)
203 |
204 | if is_training:
205 | tok_cause_start_position = orig_to_tok_index[example.start_cause_position]
206 | if example.end_cause_position < len(example.doc_tokens) - 1:
207 | tok_cause_end_position = orig_to_tok_index[example.end_cause_position + 1] - 1
208 | else:
209 | tok_cause_end_position = len(all_doc_tokens) - 1
210 |
211 | (tok_cause_start_position, tok_cause_end_position) = _improve_answer_span(
212 | all_doc_tokens, tok_cause_start_position, tok_cause_end_position, tokenizer, example.cause_text
213 | )
214 |
215 | tok_effect_start_position = orig_to_tok_index[example.start_effect_position]
216 | if example.end_effect_position < len(example.doc_tokens) - 1:
217 | tok_effect_end_position = orig_to_tok_index[example.end_effect_position + 1] - 1
218 | else:
219 | tok_effect_end_position = len(all_doc_tokens) - 1
220 |
221 | (tok_effect_start_position, tok_effect_end_position) = _improve_answer_span(
222 | all_doc_tokens, tok_effect_start_position, tok_effect_end_position, tokenizer, example.effect_text
223 | )
224 | if example.offset_sentence_2 > 0:
225 | tok_sentence_2_offset = orig_to_tok_index[example.offset_sentence_2 + 1] - 1
226 | else:
227 | tok_sentence_2_offset = None
228 | if example.offset_sentence_3 > 0:
229 | tok_sentence_3_offset = orig_to_tok_index[example.offset_sentence_3 + 1] - 1
230 | else:
231 | tok_sentence_3_offset = None
232 |
233 | spans: List[BatchEncoding] = []
234 |
235 | sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
236 |
237 | span_doc_tokens = all_doc_tokens
238 | while len(spans) * doc_stride < len(all_doc_tokens):
239 |
240 | encoded_dict: BatchEncoding = tokenizer.encode_plus(span_doc_tokens,
241 | max_length=max_seq_length,
242 | return_overflowing_tokens=True,
243 | pad_to_max_length=True,
244 | stride=max_seq_length - doc_stride - sequence_added_tokens - 1,
245 | truncation_strategy="only_first",
246 | truncation=True,
247 | return_token_type_ids=True,
248 | )
249 |
250 | paragraph_len = min(
251 | len(all_doc_tokens) - len(spans) * doc_stride,
252 | max_seq_length - sequence_added_tokens,
253 | )
254 | if tokenizer.pad_token_id in encoded_dict["input_ids"]:
255 | if tokenizer.padding_side == "right":
256 | non_padded_ids = encoded_dict.data["input_ids"][
257 | : encoded_dict.data["input_ids"].index(tokenizer.pad_token_id)]
258 | else:
259 | last_padding_id_position = (
260 | len(encoded_dict.data["input_ids"])
261 | - 1
262 | - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id)
263 | )
264 | non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1:]
265 | else:
266 | non_padded_ids = encoded_dict["input_ids"]
267 |
268 | tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
269 |
270 | token_to_orig_map = {}
271 | for i in range(paragraph_len):
272 | index = sequence_added_tokens + i
273 | token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
274 |
275 | encoded_dict["paragraph_len"] = paragraph_len
276 | encoded_dict["tokens"] = tokens
277 | encoded_dict["token_to_orig_map"] = token_to_orig_map
278 | encoded_dict["token_is_max_context"] = {}
279 | encoded_dict["start"] = len(spans) * doc_stride
280 | encoded_dict["length"] = paragraph_len
281 |
282 | spans.append(encoded_dict)
283 |
284 | if len(encoded_dict.get("overflowing_tokens", [])) == 0:
285 | break
286 | span_doc_tokens = encoded_dict["overflowing_tokens"]
287 |
288 | for doc_span_index in range(len(spans)):
289 | for j in range(spans[doc_span_index].data["paragraph_len"]):
290 | is_max_context = _check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
291 | spans[doc_span_index].data["token_is_max_context"][j] = is_max_context
292 |
293 | for span in spans:
294 | # Identify the position of the CLS token
295 | cls_index = span.data["input_ids"].index(tokenizer.cls_token_id)
296 |
297 | p_mask = np.ones(len(span.data["token_type_ids"]))
298 | p_mask[np.where(np.array(span.data["input_ids"]) == tokenizer.sep_token_id)[0]] = 1
299 | # Set the CLS index to '0'
300 | p_mask[cls_index] = 0
301 |
302 | span_is_impossible = False
303 | cause_start_position = 0
304 | cause_end_position = 0
305 | effect_start_position = 0
306 | effect_end_position = 0
307 | doc_start = span.data["start"]
308 | doc_end = span.data["start"] + span.data["length"] - 1
309 | out_of_span = False
310 | if tokenizer.padding_side == "left":
311 | doc_offset = 0
312 | else:
313 | doc_offset = sequence_added_tokens
314 | if tok_sentence_2_offset is not None:
315 | sentence_2_offset = tok_sentence_2_offset - doc_start + doc_offset
316 | else:
317 | sentence_2_offset = None
318 | if tok_sentence_3_offset is not None:
319 | sentence_3_offset = tok_sentence_3_offset - doc_start + doc_offset
320 | else:
321 | sentence_3_offset = None
322 | if is_training:
323 | # For training, if our document chunk does not contain an annotation
324 | # we throw it out, since there is nothing to predict.
325 | if not (tok_cause_start_position >= doc_start
326 | and tok_cause_end_position <= doc_end
327 | and tok_effect_start_position >= doc_start
328 | and tok_effect_end_position <= doc_end):
329 | out_of_span = True
330 |
331 | if out_of_span:
332 | cause_start_position = cls_index
333 | cause_end_position = cls_index
334 | effect_start_position = cls_index
335 | effect_end_position = cls_index
336 | span_is_impossible = True
337 | else:
338 | cause_start_position = tok_cause_start_position - doc_start + doc_offset
339 | cause_end_position = tok_cause_end_position - doc_start + doc_offset
340 | effect_start_position = tok_effect_start_position - doc_start + doc_offset
341 | effect_end_position = tok_effect_end_position - doc_start + doc_offset
342 |
343 | features.append(
344 | FinCausalFeatures(
345 | span["input_ids"],
346 | span["attention_mask"],
347 | span["token_type_ids"],
348 | cls_index,
349 | p_mask.tolist(),
350 | example_orig_index=example.example_id,
351 | example_index=0,
352 | unique_id=0,
353 | paragraph_len=span["paragraph_len"],
354 | token_is_max_context=span["token_is_max_context"],
355 | tokens=span["tokens"],
356 | token_to_orig_map=span["token_to_orig_map"],
357 | cause_start_position=cause_start_position,
358 | cause_end_position=cause_end_position,
359 | effect_start_position=effect_start_position,
360 | effect_end_position=effect_end_position,
361 | sentence_2_offset=sentence_2_offset,
362 | sentence_3_offset=sentence_3_offset,
363 | is_impossible=span_is_impossible,
364 | )
365 | )
366 | return features
367 |
368 |
369 | def fincausal_convert_examples_to_features(examples: List[FinCausalExample],
370 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
371 | max_seq_length: int,
372 | doc_stride: int,
373 | is_training: bool, return_dataset: Union[bool, str] = False,
374 | threads: int = 1) -> Union[List[FinCausalFeatures],
375 | Tuple[List[FinCausalFeatures], TensorDataset]]:
376 | with multiprocessing.Pool(threads,
377 | initializer=fincausal_convert_example_to_features_init,
378 | initargs=(tokenizer,)) as p:
379 | annotate_ = partial(
380 | fincausal_convert_example_to_features,
381 | max_seq_length=max_seq_length,
382 | doc_stride=doc_stride,
383 | is_training=is_training,
384 | )
385 | features: List[FinCausalFeatures] = list(
386 | tqdm(
387 | p.imap(annotate_, examples, chunksize=32),
388 | total=len(examples),
389 |                 desc="convert fincausal examples to features",
390 | position=0,
391 | leave=True
392 | )
393 | )
394 | new_features = []
395 | unique_id = 1000000000
396 | example_index = 0
397 | for example_features in tqdm(features, total=len(features), desc="add example index and unique id",
398 | position=0, leave=True):
399 | if not example_features:
400 | continue
401 | for example_feature in example_features:
402 | example_feature.example_index = example_index
403 | example_feature.unique_id = unique_id
404 | new_features.append(example_feature)
405 | unique_id += 1
406 | example_index += 1
407 | features = new_features
408 | del new_features
409 | if return_dataset == "pt":
410 |
411 | # Convert to Tensors and build dataset
412 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
413 | all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
414 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
415 | all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
416 | all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
417 | all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
418 |
419 | if not is_training:
420 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
421 | dataset = TensorDataset(
422 | all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask
423 | )
424 | else:
425 | all_cause_start_positions = torch.tensor([f.cause_start_position for f in features], dtype=torch.long)
426 | all_cause_end_positions = torch.tensor([f.cause_end_position for f in features], dtype=torch.long)
427 | all_effect_start_positions = torch.tensor([f.effect_start_position for f in features], dtype=torch.long)
428 | all_effect_end_positions = torch.tensor([f.effect_end_position for f in features], dtype=torch.long)
429 | dataset = TensorDataset(
430 | all_input_ids,
431 | all_attention_masks,
432 | all_token_type_ids,
433 | all_cause_start_positions,
434 | all_cause_end_positions,
435 | all_effect_start_positions,
436 | all_effect_end_positions,
437 | all_cls_index,
438 | all_p_mask,
439 | all_is_impossible,
440 | )
441 |
442 | return features, dataset
443 | return features
444 |
445 |
446 | def _run_split_on_punc(text: str) -> str:
447 | chars = list(text)
448 | i = 0
449 | start_new_word = True
450 | output = []
451 | while i < len(chars):
452 | char = chars[i]
453 | if _is_punctuation(char):
454 | output.append([char])
455 | start_new_word = True
456 | else:
457 | if start_new_word:
458 | output.append([])
459 | start_new_word = False
460 | output[-1].append(char)
461 | i += 1
462 |
463 | return " ".join(["".join(x) for x in output])
464 |
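465 | 
466 | # ---------------------------------------------------------------------------
467 | # Illustrative usage sketch (not part of the original module). The checkpoint
468 | # name and data path below are hypothetical placeholders, and RunConfig is
469 | # assumed to provide max_seq_length and doc_stride with sensible defaults:
470 | #
471 | #     from pathlib import Path
472 | #     from transformers import BertTokenizer
473 | #     from src.config import RunConfig
474 | #     from src.preprocessing import load_examples
475 | #
476 | #     tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
477 | #     dataset, examples, features = load_examples(
478 | #         file_path=Path('data/fnp2020-train.csv'),
479 | #         tokenizer=tokenizer,
480 | #         run_config=RunConfig(),
481 | #         output_examples=True,
482 | #         evaluate=False,
483 | #     )
484 | #
485 | # Each example may yield several overlapping spans (controlled by doc_stride),
486 | # and each span becomes one FinCausalFeatures entry in the returned dataset.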
--------------------------------------------------------------------------------
/src/fincausal_evaluation/task2_evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyrights YseopLab 2020. All rights reserved.
4 | # Original evaluation file from https://github.com/yseop/YseopLab/blob/develop/FNP_2020_FinCausal/scoring/task2/task2_evaluate.py
5 | """ task2_evaluate.py - Scoring program for Fincausal 2020 Task 2
6 |
7 | usage: task2_evaluate.py [-h] {from-folder,from-file} ...
8 |
9 | positional arguments:
10 | {from-folder,from-file}
11 | Use from-file for basic mode or from-folder for
12 | Codalab compatible mode
13 |
14 | Usage 1: Folder mode
15 |
16 | usage: task2_evaluate.py from-folder [-h] input output
17 |
18 | Codalab mode with input and output folders
19 |
20 | positional arguments:
21 | input input folder with ref (reference) and res (result) sub folders
22 |   output      output folder where scores.txt is written
23 |
24 | optional arguments:
25 | -h, --help show this help message and exit
26 | 
27 | example: task2_evaluate.py from-folder input output
28 | input, output folders must follow the Codalab competition convention for scoring bundle
29 | e.g.
30 | ├───input
31 | │ ├───ref
32 | │ └───res
33 | └───output
34 |
35 | Usage 2: File mode
36 |
37 | usage: task2_evaluate.py from-file [-h] [--ref_file REF_FILE] pred_file [score_file]
38 |
39 | Basic mode with path to input and output files
40 |
41 | positional arguments:
42 |   pred_file            prediction file to evaluate
43 |   score_file           path to output score file (or stdout if not provided)
44 | 
45 | optional arguments:
46 |   -h, --help           show this help message and exit
47 |   --ref_file REF_FILE  reference file (default: ../../data/fnp2020-fincausal-task2.csv)
48 | """
49 | import argparse
50 | import logging
51 | import os
52 | import unittest
53 |
54 | from collections import namedtuple
55 |
56 | import nltk
57 |
58 | from sklearn import metrics
59 |
60 |
61 | def build_token_index(text):
62 | """
63 | build a dictionary of all tokenized items from text with their respective positions in the text.
64 | E.g. "this is a basic example of a basic method" returns
65 | {'this': [0], 'is': [1], 'a': [2, 6], 'basic': [3, 7], 'example': [4], 'of': [5], 'method': [8]}
66 | :param text: reference text to index
67 |     :return: dict of text tokens, each token with its respective position(s) in the text
68 | """
69 | tokens = nltk.word_tokenize(text)
70 | token_index = {}
71 | for position, token in enumerate(tokens):
72 | if token in token_index:
73 | token_index[token].append(position)
74 | else:
75 | token_index[token] = [position]
76 | return tokens, token_index
77 |
78 |
79 | def get_tokens_sequence(text, token_index):
80 | tokens = nltk.word_tokenize(text)
81 | # build list of possible position for each token
82 | # positions = [word_index[word] for word in words]
83 | positions = []
84 | for token in tokens:
85 | if token in token_index:
86 | positions.append(token_index[token])
87 | continue
88 | # Special case when '.' is not tokenized properly
89 | alt_token = ''.join([token, '.'])
90 | if alt_token in token_index:
91 | logging.debug(f'tokenize fix ".": {alt_token}')
92 | positions.append(token_index[alt_token])
93 | # TODO: discard the next token if == '.' ?
94 | continue
95 | # Special case when/if ',' is not tokenized properly - TBC
96 | # alt_token = ''.join([token, ','])
97 | # if alt_token in token_index:
98 | # logging.debug(f'tokenize fix ",": {alt_token}')
99 | # positions.append(token_index[alt_token])
100 | # continue
101 | else:
102 | logging.warning(f'get_tokens_sequence "{token}" discarded')
103 | # No matching ? stop here
104 | if len(positions) == 0:
105 | return positions
106 | # recursively process the list of token positions to return combinations of consecutive tokens
107 | seqs = _get_sequences(*positions)
108 | # Note: several sequences can possibly be found in the reference text, when similar text patterns are repeated
109 | # always return the longest
110 | return max(seqs, key=len)
111 |
112 |
113 | def _get_sequences(*args, value=None, path=None):
114 | """
115 | Recursive method to select sequences of successive tokens using their position relative to the reference text.
116 | A sequence is the list of successive indexes in the tokenized reference text.
117 |     Implemented as a product() of successive positions, constrained by the distance to the previous token.
118 |     :param args: list of list of positions
119 |     :param value: position of the previous token (i.e. next token position must be in range [value+1, value+2])
120 | :param path: debugging - current sequence
121 | :return:
122 | """
123 | logging.debug(path)
124 | # end of recursion
125 | if len(args) == 1:
126 | if value is not None:
127 | # return items matching constraint (i.e. within range with previous token)
128 | return [x for x in args[0] if x > value and (x < value + 3)]
129 | else:
130 | # Special case where text is restricted to a single token
131 | # return all positions on first call (i.e. value is None)
132 | return [args[0]]
133 | else:
134 | # iterate over current token possible positions and combine with other tokens from recursive call
135 | # result is a list of explored sequences (i.e. list of list of positions)
136 | result = []
137 | for x in args[0]:
138 | # keep track of current explored sequence
139 | p = [x] if path is None else list(path + [x])
140 | #
141 | if value is None or (x > value and (x < value + 3)):
142 | seqs = _get_sequences(*args[1:], value=x, path=p)
143 | # when recursion returns empty list and current position match constraint (either only value
144 | # or value within range) add current position as a single result
145 | if len(seqs) == 0 and (value is None or (x > value and (x < value + 3))):
146 | result.append([x])
147 | else:
148 | # otherwise combine current position with recursion results (whether returned sequences are list
149 | # or single number) and add to the list of results for this token position
150 | for s in seqs:
151 | res = [x] + s if type(s) is list else [x, s]
152 | result.append(res)
153 | return result
154 |
155 |
156 | def encode_causal_tokens(text, cause, effect):
157 | """
158 | Encode text, cause and effect into a single list with each token represented by their respective
159 | class labels ('-','C','E')
160 | :param text: reference text
161 | :param cause: causal substring in reference text
162 | :param effect: effect substring in reference text
163 | :return: text string converted as a list of tuple(token, label)
164 | """
165 | # Get reference text tokens and token index
166 | logging.debug(f'Reference: {text}')
167 | words, wi = build_token_index(text)
168 | logging.debug(f'Token index: {wi}')
169 |
170 | # init labels with default class label
171 | labels = ['-' for _ in range(len(words))]
172 |
173 | # encode cause using token index
174 | logging.debug(f'Cause: {cause}')
175 | cause_seq = get_tokens_sequence(cause, wi)
176 | logging.debug(f'Cause seq.: {cause_seq}')
177 | for position in cause_seq:
178 | labels[position] = 'C'
179 |
180 | # encode effect using token index
181 | logging.debug(f'Effect: {effect}')
182 | effect_seq = get_tokens_sequence(effect, wi)
183 | logging.debug(f'Effect seq.: {effect_seq}')
184 | for position in effect_seq:
185 | labels[position] = 'E'
186 |
187 | logging.debug(labels)
188 |
189 | return zip(words, labels)
190 |
191 |
192 | def evaluate(truth, predict, classes):
193 | """
194 |     Fincausal 2020 Task 2 evaluation: returns precision, recall and F1 comparing submitted data to reference data.
195 | :param truth: list of Task2Data(index, text, cause, effect, labels) - reference data set
196 | :param predict: list of Task2Data(index, text, cause, effect, labels) - submission data set
197 | :param classes: list of classes
198 | :return: tuple(precision, recall, f1, exact match)
199 | """
200 | exact_match = 0
201 | y_truth = []
202 | y_predict = []
203 | multi = {}
204 | # First pass - process text sections with single causal relations and store others in `multi` dict()
205 | for t, p in zip(truth, predict):
206 | # Process Exact Match
207 | exact_match += 1 if all([x == y for x, y in zip(t.labels, p.labels)]) else 0
208 |         # PRF: does this text section have multiple causal relationships?
209 | if t.index.count('.') == 2:
210 | # extract root index and add to the list to be processed later
211 | root_index = '.'.join(t.index.split('.')[:-1])
212 | if root_index in multi:
213 | multi[root_index][0].append(t.labels)
214 | multi[root_index][1].append(p.labels)
215 | else:
216 | multi[root_index] = [[t.labels], [p.labels]]
217 | else:
218 | # Accumulate data for precision, recall, f1 scores
219 | y_truth.extend(t.labels)
220 | y_predict.extend(p.labels)
221 | # Second pass - deal with text sections having multiple causal relations
222 | for index, section in multi.items():
223 | # section[0] list of possible truth labels
224 | # section[1] list of predicted labels
225 | candidates = section[1]
226 | # for each possible combination of truth labels - try to find the best match in predicted labels
227 | # then repeat, removing this match from the list of remaining predicted labels
228 | for t in section[0]:
229 | best = None
230 | for p in candidates:
231 | f1 = metrics.f1_score(t, p, labels=classes, average='weighted', zero_division=0)
232 | if best is None or f1 > best[1]:
233 | best = (p, f1)
234 | # Use best to add to global evaluation
235 | y_truth.extend(t)
236 | y_predict.extend(best[0])
237 | # Remove best from list of candidate for next iteration
238 | candidates.remove(best[0])
239 | # Ensure all candidate predictions have been reviewed
240 | assert len(candidates) == 0
241 |
242 | precision, recall, f1, _ = metrics.precision_recall_fscore_support(y_truth, y_predict,
243 | labels=classes,
244 | average='weighted',
245 | zero_division=0)
246 |
247 | import numpy as np
248 | """
249 | Sklearn Multiclass confusion matrix is:
250 | y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
251 | y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
252 | cmc = multilabel_confusion_matrix(y_true, y_pred, labels=["ant", "bird", "cat"])
253 | cmc
254 | array([[[3, 1],
255 | [0, 2]],
256 |
257 | [[5, 0],
258 | [1, 0]],
259 |
260 | [[2, 1],
261 | [1, 2]]])
262 |
263 | cm_ant = cmc[0]
264 | tp = cm_ant[1,1] has been classified as ant and is ant
265 | fp = cm_ant[0,1] has been classified as ant and is sth else
266 | fn = cm_ant[1,0] has been classified as sth else than ant and is ant
267 | tn = cm_ant[0,0] has been classified as sth else than ant and is something else
268 |
269 | """
270 | print(' ')
271 | print('raw metrics')
272 | print("#-------------------------------------------------------------------------------------#")
273 |
274 | MCM = metrics.multilabel_confusion_matrix(y_truth, y_predict, labels=classes)
275 | tp_sum = MCM[:, 1, 1]
276 | print('tp_sum ', tp_sum)
277 | pred_sum = tp_sum + MCM[:, 0, 1]
278 | print('predicted in the classes ', pred_sum)
279 | true_sum = tp_sum + MCM[:, 1, 0]
280 | print('actually in the classes (tp + fn: support) ', true_sum)
281 |
282 | """beta : float, 1.0 by default
283 | The strength of recall versus precision in the F-score."""
284 |
285 | precision_ = tp_sum / pred_sum
286 | recall_ = tp_sum / true_sum
287 | beta = 1.0
288 | beta2 = beta ** 2
289 | denom = beta2 * precision_ + recall_
290 | print('denom', denom)
291 | denom[denom == 0.] = 1 # avoid division by 0
292 | weights = true_sum
293 | print("weights", weights)
294 | print(' ')
295 | print("#-------------------------------------------------------------------------------------#")
296 | print('macro scores')
297 | f_score_ = (1 + beta2) * precision_ * recall_ / denom
298 |     print('macro precision ', sum(precision_) / len(classes), 'macro recall', sum(recall_) / len(classes),
299 |           'macro f_score ', sum(f_score_) / len(classes))
300 |
301 | ## recompute for average from source
302 |
303 | precision_ = np.average(precision_, weights=weights)
304 | recall_ = np.average(recall_, weights=weights)
305 | f_score_ = np.average(f_score_, weights=weights)
306 | print(' ')
307 | print("#-------------------------------------------------------------------------------------#")
308 | print('recomputed weighted metrics ')
309 | print('weighted precision', precision_, 'weighted recall', recall_, 'weighted_fscore', f_score_)
310 | print(' ')
311 | print("#-------------------------------------------------------------------------------------#")
312 | print('classification report')
313 | print(metrics.classification_report(y_truth, y_predict, target_names=classes))
314 |
315 | logging.debug(f'SKLEARN EVAL: {f1}, {precision}, {recall}')
316 |
317 | return precision, recall, f1, float(exact_match) / float(len(truth))
318 |
319 |
320 | Task2Data = namedtuple('Task2Data', ['index', 'text', 'cause', 'effect', 'labels'])
321 |
322 |
323 | def get_data(csv_lines):
324 | """
325 | Retrieve Task 2 data from CSV content (separator is ';') as a list of (index, text, cause, effect).
326 | :param csv_lines:
327 | :return: list of Task2Data(index, text, cause, effect, labels)
328 | """
329 | result = []
330 | for line in csv_lines:
331 | line = line.rstrip('\n')
332 |
333 | index, text, cause, effect = line.split(';')[:4]
334 |
335 | text = text.lstrip()
336 | cause = cause.lstrip()
337 | effect = effect.lstrip()
338 |
339 | _, labels = zip(*encode_causal_tokens(text, cause, effect))
340 |
341 | result.append(Task2Data(index, text, cause, effect, labels))
342 |
343 | return result
344 |
345 |
346 | def evaluate_files(gold_file, submission_file, output_file=None):
347 | """
348 | Evaluate Precision, Recall, F1 scores between gold_file and submission_file
349 | If output_file is provided, scores are saved in this file and printed to std output.
350 | :param gold_file: path to reference data
351 | :param submission_file: path to submitted data
352 | :param output_file: path to output file as expected by Codalab competition framework
353 | :return:
354 | """
355 | if os.path.exists(gold_file) and os.path.exists(submission_file):
356 | with open(gold_file, 'r', encoding='utf-8') as fp:
357 | ref_csv = fp.readlines()
358 | with open(submission_file, 'r', encoding='utf-8') as fp:
359 | sub_csv = fp.readlines()
360 |
361 | # Get data (skipping headers)
362 | logging.info('* Loading reference data')
363 | y_true = get_data(ref_csv[1:])
364 | logging.info('* Loading prediction data')
365 | y_pred = get_data(sub_csv[1:])
366 |
367 | logging.info(f'Load Data: check data set length = {len(y_true) == len(y_pred)}')
368 | logging.info(f'Load Data: check data set ref. text = {all([x.text == y.text for x, y in zip(y_true, y_pred)])}')
369 | assert len(y_true) == len(y_pred)
370 | assert all([x.text == y.text for x, y in zip(y_true, y_pred)])
371 |
372 | # Process data using classes: -, C & E
373 | precision, recall, f1, exact_match = evaluate(y_true, y_pred, ['-', 'C', 'E'])
374 |
375 | scores = [
376 | "F1: %f\n" % f1,
377 | "Recall: %f\n" % recall,
378 | "Precision: %f\n" % precision,
379 | "ExactMatch: %f\n" % exact_match
380 | ]
381 |
382 |         for s in scores:
383 |             print(s, end='')
384 |         if output_file is not None:
385 |             with open(output_file, 'w', encoding='utf-8') as fp:
386 |                 for s in scores:
387 |                     fp.write(s)
388 | 
389 |         # Save predictions next to the ground truth for manual inspection
390 |         # (kept inside this branch so y_true / y_pred are guaranteed to exist)
391 |         import pandas as pd
392 |         df = pd.DataFrame.from_records(y_true)
393 |         df.columns = ['Index', 'Text', 'Cause', 'Effect', 'TRUTH']
394 |         dfpred = pd.DataFrame.from_records(y_pred)
395 |         dfpred.columns = ['Index', 'Text', 'Cause', 'Effect', 'PRED']
396 |         df['PRED'] = dfpred['PRED']
397 |         df['TRUTH'] = df['TRUTH'].apply(lambda x: ' '.join(x))
398 |         df['PRED'] = df['PRED'].apply(lambda x: ' '.join(x))
399 | 
400 |         ctrlpath = os.path.dirname(submission_file)
401 |         df.to_csv(os.path.join(ctrlpath, 'origin_control.csv'), header=1, index=0)
402 |     else:
403 |         # The submission file is most likely the wrong one:
404 |         # report which file is expected
405 |         logging.error(f'{os.path.basename(gold_file)} not found')
406 |
407 |
408 | def from_folder(args):
409 | # Folder mode - Codalab usage
410 | submit_dir = os.path.join(args.input, 'res')
411 | truth_dir = os.path.join(args.input, 'ref')
412 | output_dir = args.output
413 |
414 | if not os.path.isdir(submit_dir):
415 | logging.error("%s doesn't exist" % submit_dir)
416 |
417 | if os.path.isdir(submit_dir) and os.path.isdir(truth_dir):
418 | if not os.path.exists(output_dir):
419 | os.makedirs(output_dir)
420 |
421 | o_file = os.path.join(output_dir, 'scores.txt')
422 |
423 | gold_list = os.listdir(truth_dir)
424 | for gold in gold_list:
425 | g_file = os.path.join(truth_dir, gold)
426 | s_file = os.path.join(submit_dir, gold)
427 |
428 | evaluate_files(g_file, s_file, o_file)
429 |
430 |
431 | def from_file(args):
432 | return evaluate_files(args.ref_file, args.pred_file, args.score_file)
433 |
434 |
435 | def main():
436 | parser = argparse.ArgumentParser()
437 | subparsers = parser.add_subparsers(help='Use from-file for basic mode or from-folder for Codalab compatible mode')
438 |
439 | command1_parser = subparsers.add_parser('from-folder', description='Codalab mode with input and output folders')
440 | command1_parser.set_defaults(func=from_folder)
441 | command1_parser.add_argument('input', help='input folder with ref (reference) and res (result) sub folders')
442 |     command1_parser.add_argument('output', help='output folder where scores.txt is written')
443 |
444 | command2_parser = subparsers.add_parser('from-file', description='Basic mode with path to input and output files')
445 | command2_parser.set_defaults(func=from_file)
446 | command2_parser.add_argument('--ref_file', default='../../data/fnp2020-fincausal-task2.csv', help='reference file')
447 | command2_parser.add_argument('pred_file', help='prediction file to evaluate')
448 | command2_parser.add_argument('score_file', nargs='?', default=None,
449 | help='path to output score file (or stdout if not provided)')
450 |
451 | logging.basicConfig(level=logging.INFO,
452 | filename=None,
453 | format='%(levelname)-7s| %(message)s')
454 |
455 | args = parser.parse_args()
456 | if 'func' in args:
457 | exit(args.func(args))
458 | else:
459 | parser.print_usage()
460 | exit(1)
461 |
462 |
463 | if __name__ == '__main__':
464 | main()
465 |
466 |
467 | # Tests, which can be executed with `python -m unittest task2_evaluate`.
468 | class Test(unittest.TestCase):
469 | def _process_test_ok(self, t_text, p_text, t_labels, p_labels, f1, precision, recall, exact_match):
470 | # Load data
471 | y_true = get_data(t_text)
472 | y_pred = get_data(p_text)
473 | with self.subTest(value='encode_causal_truth'):
474 | self.assertEqual(y_true[0].labels, t_labels)
475 | with self.subTest(value='encode_causal_pred'):
476 | self.assertEqual(y_pred[0].labels, p_labels)
477 | # Evaluate precision, recall, f1 and exact matches
478 | result = evaluate(y_true, y_pred, ['-', 'C', 'E'])
479 | # Round result to 2 decimals
480 | result = tuple(map(lambda x: round(x, 2), result))
481 | with self.subTest(value='evaluate'):
482 | self.assertEqual(result, (precision, recall, f1, exact_match))
483 |
484 | def test_0(self):
485 | """ Identity """
486 | self._process_test_ok(
487 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj.; aaa bbbb; hhh i jjjj.\n'],
488 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj.; aaa bbbb; hhh i jjjj.\n'],
489 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E', 'E'),
490 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E', 'E'),
491 | 1.00, 1.00, 1.00, 1.00
492 | )
493 |
494 | def test_1(self):
495 | """ single failure """
496 | self._process_test_ok(
497 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa bbbb; hhh i jjjj\n'],
498 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb; hhh i jjjj\n'],
499 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E'),
500 | ('-', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E'),
501 | 0.92, 0.9, 0.89, 0.0
502 | )
503 |
504 | def test_2(self):
505 | """ 1 missed 1 error """
506 | self._process_test_ok(
507 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa bbbb; hhh i jjjj\n'],
508 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb; gggg hhh i jjjj\n'],
509 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E'),
510 | ('-', 'C', '-', '-', '-', '-', 'E', 'E', 'E', 'E'),
511 | 0.82, 0.8, 0.79, 0.0
512 | )
513 |
514 | def test_3(self):
515 | """ total failure """
516 | self._process_test_ok(
517 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj.; aaa bbbb; hhh i jjjj.\n'],
518 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj.; ccc d; eeee ffff gggg\n'],
519 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E', 'E'),
520 | ('-', '-', 'C', 'C', 'E', 'E', 'E', '-', '-', '-', '-'),
521 | 0.0, 0.0, 0.0, 0.0
522 | )
523 |
524 | def test_4(self):
525 | """ punctuation """
526 | self._process_test_ok(
527 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj.; aaa bbbb; hhh i jjjj.\n'],
528 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj.; aaa bbbb; hhh i jjjj\n'],
529 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E', 'E'),
530 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E', '-'),
531 | 0.92, 0.91, 0.91, 0.0
532 | )
533 |
534 | def test_5(self):
535 | """ 2 missed + punctuation """
536 | self._process_test_ok(
537 | ['1.0; aaa bbbb ccc d aaa bbbb ccc hhh i jjjj.; aaa bbbb ccc d; hhh i jjjj.\n'],
538 | ['1.0; aaa bbbb ccc d aaa bbbb ccc hhh i jjjj.; aaa bbbb; hhh i jjjj\n'],
539 | ('C', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E', 'E'),
540 | ('C', 'C', '-', '-', '-', '-', '-', 'E', 'E', 'E', '-'),
541 | 0.86, 0.73, 0.74, 0.0
542 | )
543 |
544 | def test_6(self):
545 | """ non consecutive tokens (out of range) """
546 | self._process_test_ok(
547 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa bbbb ccc d; hhh i jjjj\n'],
548 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa d; hhh i jjjj\n'],
549 | ('C', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'),
550 | ('C', '-', '-', '-', '-', '-', '-', 'E', 'E', 'E'),
551 | 0.85, 0.7, 0.66, 0.0
552 | )
553 |
554 | def test_7(self):
555 | """ non consecutive tokens (in range) """
556 | self._process_test_ok(
557 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa bbbb ccc d; hhh i jjjj\n'],
558 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb d; hhh i jjjj\n'],
559 | ('C', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'),
560 | ('-', 'C', '-', 'C', '-', '-', '-', 'E', 'E', 'E'),
561 | 0.88, 0.8, 0.79, 0.0
562 | )
563 |
564 | def test_8(self):
565 | """ no effect """
566 | self._process_test_ok(
567 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d; hhh i jjjj\n'],
568 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d;\n'],
569 | ('-', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'),
570 | ('-', 'C', 'C', 'C', '-', '-', '-', '-', '-', '-'),
571 | 0.53, 0.7, 0.59, 0.0
572 | )
573 |
574 | def test_9(self):
575 | """ no cause """
576 | self._process_test_ok(
577 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d; hhh i jjjj\n'],
578 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; ; bbbb ccc d\n'],
579 | ('-', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'),
580 | ('-', 'E', 'E', 'E', '-', '-', '-', '-', '-', '-'),
581 | 0.23, 0.4, 0.29, 0.0
582 | )
583 |
584 | def test_10(self):
585 | """ no cause, no effect """
586 | self._process_test_ok(
587 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d; hhh i jjjj\n'],
588 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; ; \n'],
589 | ('-', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'),
590 | ('-', '-', '-', '-', '-', '-', '-', '-', '-', '-'),
591 | 0.16, 0.4, 0.23, 0.0
592 | )
593 |
594 | def test_11(self):
595 | """ all cause """
596 | self._process_test_ok(
597 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d; hhh i jjjj\n'],
598 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; \n'],
599 | ('-', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'),
600 | ('C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C'),
601 | 0.09, 0.3, 0.14, 0.0
602 | )
603 |
604 | def test_12(self):
605 | """ half cause """
606 | self._process_test_ok(
607 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d; hhh i jjjj\n'],
608 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; aaa ccc eeee gggg i;\n'],
609 | ('-', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'),
610 | ('C', '-', 'C', '-', 'C', '-', 'C', '-', 'C', '-'),
611 | 0.14, 0.2, 0.16, 0.0
612 | )
613 |
614 | def test_13(self):
615 | """ all effect """
616 | self._process_test_ok(
617 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; bbbb ccc d; hhh i jjjj\n'],
618 | ['1.0; aaa bbbb ccc d eeee ffff gggg hhh i jjjj; ; aaa bbbb ccc d eeee ffff gggg hhh i jjjj\n'],
619 | ('-', 'C', 'C', 'C', '-', '-', '-', 'E', 'E', 'E'),
620 | ('E', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'E'),
621 | 0.09, 0.3, 0.14, 0.0
622 | )
623 |
--------------------------------------------------------------------------------
/src/evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | # Copyright 2020 Guillaume Becquin.
4 | # MODIFIED FOR CAUSE EFFECT EXTRACTION
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import collections
19 | import csv
20 | import json
21 | import logging
22 | import math
23 | from pathlib import Path
24 | from typing import Dict, List, Tuple, Union
25 |
26 | import torch
27 | from torch.nn import Module
28 | from torch.utils.data import SequentialSampler, DataLoader
29 | from tqdm import tqdm
30 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
31 |
32 | from .config import RunConfig
33 | from .fincausal_evaluation.task2_evaluate import encode_causal_tokens, Task2Data
34 | from .data import FinCausalResult, FinCausalFeatures, FinCausalExample
35 | from .preprocessing import load_examples
36 | from .fincausal_evaluation.task2_evaluate import evaluate as official_evaluate
37 |
38 | logger = logging.getLogger(__name__)
39 |
40 |
41 | def to_list(tensor: torch.Tensor) -> List:
42 | return tensor.detach().cpu().tolist()
43 |
44 |
45 | def predict(model: Module,
46 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
47 | device: torch.device,
48 | file_path: Path,
49 | model_type: str,
50 | output_dir: Path,
51 | run_config: RunConfig) -> Tuple[List[FinCausalExample], collections.OrderedDict]:
52 | dataset, examples, features = load_examples(file_path=file_path,
53 | tokenizer=tokenizer,
54 | output_examples=True,
55 | evaluate=True,
56 | run_config=run_config)
57 |
58 | if not output_dir.is_dir():
59 | output_dir.mkdir(parents=True, exist_ok=True)
60 |
61 | eval_sampler = SequentialSampler(dataset)
62 | eval_dataloader = DataLoader(dataset,
63 | sampler=eval_sampler,
64 | batch_size=run_config.eval_batch_size)
65 |
66 | # Start evaluation
67 | logger.info("***** Running evaluation *****")
68 | logger.info(" Num examples = %d", len(dataset))
69 | logger.info(" Batch size = %d", run_config.eval_batch_size)
70 |
71 | all_results = []
72 | sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
73 |
74 | for batch in tqdm(eval_dataloader, desc="Evaluating", position=0, leave=True):
75 | model.eval()
76 | batch = tuple(t.to(device) for t in batch)
77 |
78 | with torch.no_grad():
79 | inputs = {
80 | "input_ids": batch[0],
81 | "attention_mask": batch[1],
82 | "token_type_ids": batch[2],
83 | }
84 |
85 | if model_type in ["xlm", "roberta", "distilbert", "camembert"]:
86 | del inputs["token_type_ids"]
87 |
88 | example_indices = batch[3]
89 | outputs = model(**inputs)
90 |
91 | for i, example_index in enumerate(example_indices):
92 | eval_feature = features[example_index.item()]
93 | unique_id = int(eval_feature.unique_id)
94 |
95 | output = [to_list(tensor[i]) for tensor in outputs]
96 | start_cause_logits, end_cause_logits, start_effect_logits, end_effect_logits = output
97 | result = FinCausalResult(unique_id,
98 | start_cause_logits, end_cause_logits,
99 | start_effect_logits, end_effect_logits)
100 |
101 | all_results.append(result)
102 |
103 | # Compute predictions
104 | predictions = compute_predictions_logits(
105 | examples,
106 | features,
107 | all_results,
108 | output_dir,
109 | sequence_added_tokens,
110 | run_config
111 | )
112 |
113 | return examples, predictions
114 |
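# A minimal driver sketch for predict() (hedged: the concrete model and
# tokenizer objects and the RunConfig fields are assumed to be built elsewhere
# in this repository, e.g. in main.py; the CSV path reuses the dev split
# written by utils/split_dataset.py):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     examples, predictions = predict(model=model,
#                                     tokenizer=tokenizer,
#                                     device=device,
#                                     file_path=Path("data/fnp2020-eval.csv"),
#                                     model_type="bert",
#                                     output_dir=Path("output/run1"),
#                                     run_config=run_config)
#     # predictions maps example_id -> {"text", "cause_text", "effect_text"}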
115 |
116 | def evaluate(model: Module,
117 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
118 | device: torch.device,
119 | file_path: Path,
120 | model_type: str,
121 | output_dir: Path,
122 | run_config: RunConfig) -> Dict:
123 | examples, predictions = predict(model=model,
124 | tokenizer=tokenizer,
125 | device=device,
126 | file_path=file_path,
127 | model_type=model_type,
128 | output_dir=output_dir,
129 | run_config=run_config)
130 |
131 | # Compute the F1 and exact scores.
132 | results, correct, wrong = compute_metrics(examples, predictions)
133 | output_prediction_file_correct = output_dir / "predictions_correct.json"
134 | output_prediction_file_wrong = output_dir / "predictions_wrong.json"
135 |
136 | with output_prediction_file_correct.open('w') as writer:
137 | writer.write(json.dumps(correct, indent=4) + "\n")
138 |
139 | with output_prediction_file_wrong.open('w') as writer:
140 | writer.write(json.dumps(wrong, indent=4) + "\n")
141 |
142 | return results
143 |
144 |
145 | def get_data_from_list(input_data: List[List[str]]) -> List[Task2Data]:
146 | """
147 | :param input_data: list of inputs (example id, text, cause, effect)
148 | :return: list of Task2Data(index, text, cause, effect, labels)
149 | """
150 | result = []
151 | for example in input_data:
152 | (index, text, cause, effect) = tuple(example)
153 |
154 | text = text.lstrip()
155 | cause = cause.lstrip()
156 | effect = effect.lstrip()
157 |
158 | _, labels = zip(*encode_causal_tokens(text, cause, effect))
159 |
160 | result.append(Task2Data(index, text, cause, effect, labels))
161 |
162 | return result
163 |
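# A small usage sketch for get_data_from_list (illustrative; the label
# alphabet {'-', 'C', 'E'} is taken from the official_evaluate call below,
# and the per-token tagging is delegated to encode_causal_tokens):
#
#     rows = [['1.0', 'Prices fell because supply rose.',
#              'supply rose', 'Prices fell']]
#     data = get_data_from_list(rows)
#     # data[0] is a Task2Data whose .labels is a per-token tuple over
#     # {'-', 'C', 'E'} aligned with the tokens of .text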
164 |
165 | def compute_metrics(examples: List[FinCausalExample], predictions: collections.OrderedDict) \
166 | -> Tuple[Dict, List[Dict], List[Dict]]:
167 | y_true = []
168 | y_pred = []
169 |
170 | for example in examples:
171 | y_true.append([example.example_id, example.context_text, example.cause_text, example.effect_text])
172 | prediction = predictions[example.example_id]
173 | y_pred.append([example.example_id, example.context_text, prediction['cause_text'], prediction['effect_text']])
174 |
175 | all_correct = list()
176 | all_wrong = list()
177 | for y_true_ex, y_pred_ex in zip(y_true, y_pred):
178 | if (y_true_ex[2] == y_pred_ex[2]) and (y_true_ex[3] == y_pred_ex[3]):
179 | all_correct.append({'text': y_true_ex[1],
180 | 'cause_true': y_true_ex[2],
181 | 'effect_true': y_true_ex[3],
182 | 'cause_pred': y_pred_ex[2],
183 | 'effect_pred': y_pred_ex[3]
184 | })
185 | else:
186 | all_wrong.append({'text': y_true_ex[1],
187 | 'cause_true': y_true_ex[2],
188 | 'effect_true': y_true_ex[3],
189 | 'cause_pred': y_pred_ex[2],
190 | 'effect_pred': y_pred_ex[3]
191 | })
192 | logger.info('* Loading reference data')
193 | y_true = get_data_from_list(y_true)
194 | logger.info('* Loading prediction data')
195 | y_pred = get_data_from_list(y_pred)
196 | logger.info(f'Load Data: check data set length = {len(y_true) == len(y_pred)}')
197 | logger.info(f'Load Data: check data set ref. text = {all([x.text == y.text for x, y in zip(y_true, y_pred)])}')
198 | assert len(y_true) == len(y_pred)
199 | assert all([x.text == y.text for x, y in zip(y_true, y_pred)])
200 |
201 | precision, recall, f1, exact_match = official_evaluate(y_true, y_pred, ['-', 'C', 'E'])
202 |
203 | scores = [
204 | "F1: %f\n" % f1,
205 | "Recall: %f\n" % recall,
206 | "Precision: %f\n" % precision,
207 | "ExactMatch: %f\n" % exact_match
208 | ]
209 | for s in scores:
210 | print(s, end='')
211 |
212 | return {
213 | 'F1': f1,
214 | 'Precision': precision,
215 | 'Recall': recall,
216 | 'ExactMatch': exact_match
217 | }, all_correct, all_wrong
218 |
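# Shape of the compute_metrics return values, as a sketch (the numbers are
# hypothetical):
#
#     scores, correct, wrong = compute_metrics(examples, predictions)
#     # scores == {'F1': 0.83, 'Precision': 0.84, 'Recall': 0.82,
#     #            'ExactMatch': 0.71}
#     # each entry of correct/wrong has the keys 'text', 'cause_true',
#     # 'effect_true', 'cause_pred', 'effect_pred'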
219 |
220 | _PrelimPrediction = collections.namedtuple(
221 | "PrelimPrediction", ["feature_index",
222 | "start_index_cause",
223 | "end_index_cause",
224 | "start_logit_cause",
225 | "end_logit_cause",
226 | "start_index_effect",
227 | "end_index_effect",
228 | "start_logit_effect",
229 | "end_logit_effect"]
230 | )
231 |
232 | _NbestPrediction = collections.namedtuple(
233 | "NbestPrediction", ["text_cause",
234 | "start_index_cause",
235 | "end_index_cause",
236 | "start_logit_cause",
237 | "end_logit_cause",
238 | "text_effect",
239 | "start_index_effect",
240 | "end_index_effect",
241 | "start_logit_effect",
242 | "end_logit_effect"]
243 | )
244 |
245 |
246 | def filter_impossible_spans(features,
247 | unique_id_to_result: Dict,
248 | n_best_size: int,
249 | max_answer_length: int,
250 | sequence_added_tokens: int,
251 | sentence_boundary_heuristic: bool = False,
252 | full_sentence_heuristic: bool = False,
253 | shared_sentence_heuristic: bool = False,
254 | ) -> List[_PrelimPrediction]:
255 | prelim_predictions = []
256 |
257 | for (feature_index, feature) in enumerate(features):
258 | result = unique_id_to_result[feature.unique_id]
259 | assert isinstance(feature, FinCausalFeatures)
260 | assert isinstance(result, FinCausalResult)
261 | sentence_offsets = [offset for offset in [feature.sentence_2_offset, feature.sentence_3_offset] if
262 | offset is not None]
263 | start_indexes_cause = _get_best_indexes(result.start_cause_logits, n_best_size)
264 | end_indexes_cause = _get_best_indexes(result.end_cause_logits, n_best_size)
265 | start_logits_cause = result.start_cause_logits
266 | end_logits_cause = result.end_cause_logits
267 | start_indexes_effect = _get_best_indexes(result.start_effect_logits, n_best_size)
268 | end_indexes_effect = _get_best_indexes(result.end_effect_logits, n_best_size)
269 | start_logits_effect = result.start_effect_logits
270 | end_logits_effect = result.end_effect_logits
271 |
272 | for raw_start_index_cause in start_indexes_cause:
273 | for raw_end_index_cause in end_indexes_cause:
274 | cause_pairs = [(raw_start_index_cause, raw_end_index_cause)]
275 | # Heuristic: a cause cannot span across multiple sentences
276 | if len(sentence_offsets) > 0 and sentence_boundary_heuristic:
277 | for sentence_offset in sentence_offsets:
278 | if raw_start_index_cause < sentence_offset < raw_end_index_cause:
279 | cause_pairs = [(raw_start_index_cause, sentence_offset),
280 | (sentence_offset + 1, raw_end_index_cause)]
281 | for start_index_cause, end_index_cause in cause_pairs:
282 | for raw_start_index_effect in start_indexes_effect:
283 | for raw_end_index_effect in end_indexes_effect:
284 | effect_pairs = [(raw_start_index_effect, raw_end_index_effect)]
285 | # Heuristic: an effect cannot span across multiple sentences
286 | if len(sentence_offsets) > 0 and sentence_boundary_heuristic:
287 | for sentence_offset in sentence_offsets:
288 | if raw_start_index_effect < sentence_offset < raw_end_index_effect:
289 | effect_pairs = [(raw_start_index_effect, sentence_offset),
290 | (sentence_offset + 1, raw_end_index_effect)]
291 | for start_index_effect, end_index_effect in effect_pairs:
292 | if (start_index_cause <= start_index_effect) and (
293 | end_index_cause >= start_index_effect):
294 | continue
295 | if (start_index_effect <= start_index_cause) and (
296 | end_index_effect >= start_index_cause):
297 | continue
298 | if start_index_effect >= len(feature.tokens) or start_index_cause >= len(
299 | feature.tokens):
300 | continue
301 | if end_index_effect >= len(feature.tokens) or end_index_cause >= len(feature.tokens):
302 | continue
303 | if start_index_effect not in feature.token_to_orig_map or \
304 | start_index_cause not in feature.token_to_orig_map:
305 | continue
306 | if end_index_effect not in feature.token_to_orig_map or \
307 | end_index_cause not in feature.token_to_orig_map:
308 | continue
309 | if (not feature.token_is_max_context.get(start_index_effect, False)) or \
310 | (not feature.token_is_max_context.get(start_index_cause, False)):
311 | continue
312 | if end_index_cause < start_index_cause:
313 | continue
314 | if end_index_effect < start_index_effect:
315 | continue
316 | length_cause = end_index_cause - start_index_cause + 1
317 | length_effect = end_index_effect - start_index_effect + 1
318 | if length_cause > max_answer_length:
319 | continue
320 | if length_effect > max_answer_length:
321 | continue
322 |
323 | # Heuristics extending the prediction spans
324 | if full_sentence_heuristic or shared_sentence_heuristic:
325 | num_tokens = len(feature.tokens)
326 | all_sentence_offsets = [sequence_added_tokens] + \
327 | [offset + 1 for offset in sentence_offsets] + \
328 | [num_tokens]
329 | cause_sentences = []
330 | effect_sentences = []
331 | for sentence_idx in range(len(all_sentence_offsets) - 1):
332 | sentence_start, sentence_end = all_sentence_offsets[sentence_idx], \
333 | all_sentence_offsets[sentence_idx + 1]
334 | if sentence_start <= start_index_cause < sentence_end:
335 | cause_sentences.append(sentence_idx)
336 | if sentence_start <= start_index_effect < sentence_end:
337 | effect_sentences.append(sentence_idx)
338 |
339 | # Heuristic (first rule): if a sentence contains only one clause, the clause
340 | # is extended to the entire sentence.
341 | if set(cause_sentences).isdisjoint(set(effect_sentences)) \
342 | and full_sentence_heuristic:
343 | start_index_cause = min(
344 | [all_sentence_offsets[sent] for sent in cause_sentences])
345 | end_index_cause = max(
346 | [all_sentence_offsets[sent + 1] - 1 for sent in cause_sentences])
347 | start_index_effect = min(
348 | [all_sentence_offsets[sent] for sent in effect_sentences])
349 | end_index_effect = max(
350 | [all_sentence_offsets[sent + 1] - 1 for sent in effect_sentences])
351 | # Heuristic (third rule): if a sentence contains two clauses, the spans are
352 | # extended outward as far as possible towards the sentence boundaries.
353 | if not set(cause_sentences).isdisjoint(set(effect_sentences)) \
354 | and shared_sentence_heuristic \
355 | and len(cause_sentences) == 1 \
356 | and len(effect_sentences) == 1:
357 | if start_index_cause < start_index_effect:
358 | start_index_cause = min(
359 | [all_sentence_offsets[sent] for sent in cause_sentences])
360 | end_index_effect = max(
361 | [all_sentence_offsets[sent + 1] - 1 for sent in effect_sentences])
362 | else:
363 | start_index_effect = min(
364 | [all_sentence_offsets[sent] for sent in effect_sentences])
365 | end_index_cause = max(
366 | [all_sentence_offsets[sent + 1] - 1 for sent in cause_sentences])
367 |
368 | prelim_predictions.append(
369 | _PrelimPrediction(
370 | feature_index=feature_index,
371 | start_index_cause=start_index_cause,
372 | end_index_cause=end_index_cause,
373 | start_logit_cause=start_logits_cause[start_index_cause],
374 | end_logit_cause=end_logits_cause[end_index_cause],
375 | start_index_effect=start_index_effect,
376 | end_index_effect=end_index_effect,
377 | start_logit_effect=start_logits_effect[start_index_effect],
378 | end_logit_effect=end_logits_effect[end_index_effect]
379 | )
380 | )
381 | return prelim_predictions
382 |
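# A worked illustration of the span heuristics above (token indices are
# hypothetical):
#
# - sentence_boundary_heuristic: with a sentence boundary offset at token 12,
#   a candidate cause span (8, 17) is split into the within-sentence
#   candidates (8, 12) and (13, 17); effect spans are split the same way.
# - full_sentence_heuristic: if the cause sits alone in the sentence covering
#   tokens [2, 12] and the effect alone in the sentence covering [13, 30],
#   both spans are widened to their full sentences, i.e. cause (2, 12) and
#   effect (13, 30).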
383 |
384 | def get_predictions(preliminary_predictions: List[_PrelimPrediction], n_best_size: int,
385 | features: List[FinCausalFeatures], example: FinCausalExample) -> List[_NbestPrediction]:
386 | seen_predictions_cause = {}
387 | seen_predictions_effect = {}
388 | nbest = []
389 | for prediction in preliminary_predictions:
390 | if len(nbest) >= n_best_size:
391 | break
392 | feature = features[prediction.feature_index]
393 | if prediction.start_index_cause > 0: # this is a non-null prediction
394 | orig_doc_start_cause = feature.token_to_orig_map[prediction.start_index_cause]
395 | orig_doc_end_cause = feature.token_to_orig_map[prediction.end_index_cause]
396 | orig_doc_start_cause_char = example.word_to_char_mapping[orig_doc_start_cause]
397 | if orig_doc_end_cause < len(example.word_to_char_mapping) - 1:
398 | orig_doc_end_cause_char = example.word_to_char_mapping[orig_doc_end_cause + 1]
399 | else:
400 | orig_doc_end_cause_char = len(example.context_text)
401 | final_text_cause = example.context_text[orig_doc_start_cause_char: orig_doc_end_cause_char]
402 | final_text_cause = final_text_cause.strip()
403 |
404 | orig_doc_start_effect = feature.token_to_orig_map[prediction.start_index_effect]
405 | orig_doc_end_effect = feature.token_to_orig_map[prediction.end_index_effect]
406 | orig_doc_start_effect_char = example.word_to_char_mapping[orig_doc_start_effect]
407 | if orig_doc_end_effect < len(example.word_to_char_mapping) - 1:
408 | orig_doc_end_effect_char = example.word_to_char_mapping[orig_doc_end_effect + 1]
409 | else:
410 | orig_doc_end_effect_char = len(example.context_text)
411 | final_text_effect = example.context_text[orig_doc_start_effect_char: orig_doc_end_effect_char]
412 | final_text_effect = final_text_effect.strip()
413 |
414 | if final_text_cause in seen_predictions_cause and final_text_effect in seen_predictions_effect:
415 | continue
416 |
417 | seen_predictions_cause[final_text_cause] = True
418 | seen_predictions_effect[final_text_effect] = True
419 | else:
420 | final_text_cause = final_text_effect = ""
421 | seen_predictions_cause[final_text_cause] = True
422 | seen_predictions_effect[final_text_effect] = True
423 | orig_doc_start_cause = prediction.start_index_cause
424 | orig_doc_end_cause = prediction.end_index_cause
425 | orig_doc_start_effect = prediction.start_index_effect
426 | orig_doc_end_effect = prediction.end_index_effect
427 |
428 | nbest.append(
429 | _NbestPrediction(text_cause=final_text_cause,
430 | start_logit_cause=prediction.start_logit_cause,
431 | end_logit_cause=prediction.end_logit_cause,
432 | start_index_cause=orig_doc_start_cause,
433 | end_index_cause=orig_doc_end_cause,
434 | text_effect=final_text_effect,
435 | start_logit_effect=prediction.start_logit_effect,
436 | end_logit_effect=prediction.end_logit_effect,
437 | start_index_effect=orig_doc_start_effect,
438 | end_index_effect=orig_doc_end_effect,
439 | ))
440 | return nbest
441 |
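# Character-offset recovery in get_predictions, as a worked example with
# hypothetical values: if token_to_orig_map sends the predicted start/end
# tokens to words 4 and 7, and word_to_char_mapping[4] == 20 while
# word_to_char_mapping[8] == 55, the surface text is
# example.context_text[20:55].strip(); slicing up to the first character of
# the *next* word keeps every sub-word piece of word 7 inside the span.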
442 |
443 | def compute_predictions_logits(
444 | all_examples: List[FinCausalExample],
445 | all_features: List[FinCausalFeatures],
446 | all_results: List[FinCausalResult],
447 | output_dir: Path,
448 | sequence_added_tokens: int,
449 | run_config: RunConfig) -> collections.OrderedDict:
450 | example_index_to_features = collections.defaultdict(list)
451 | for feature in all_features:
452 | example_index_to_features[feature.example_index].append(feature)
453 |
454 | unique_id_to_result = {}
455 | for result in all_results:
456 | unique_id_to_result[result.unique_id] = result
457 |
458 | all_predictions = collections.OrderedDict()
459 | all_nbest_json = collections.OrderedDict()
460 |
461 | for (example_index, example) in enumerate(all_examples):
462 | features = example_index_to_features[example_index]
463 | suffix_index = 0
464 | if example.example_id.count('.') == 2 and run_config.top_n_sentences:
465 | suffix_index = int(example.example_id.split('.')[-1])
466 | prelim_predictions = filter_impossible_spans(features,
467 | unique_id_to_result,
468 | run_config.n_best_size,
469 | run_config.max_answer_length,
470 | sequence_added_tokens,
471 | run_config.sentence_boundary_heuristic,
472 | run_config.full_sentence_heuristic,
473 | run_config.shared_sentence_heuristic, )
474 | prelim_predictions = sorted(list(set(prelim_predictions)),
475 | key=lambda x: (x.start_logit_cause + x.end_logit_cause +
476 | x.start_logit_effect + x.end_logit_effect),
477 | reverse=True)
478 |
479 | nbest = get_predictions(prelim_predictions, run_config.n_best_size, features, example)
480 |
481 | # In very rare edge cases we could have no valid predictions. So we
482 | # just create a placeholder prediction in this case to avoid failure.
483 | if not nbest:
484 | nbest.append(_NbestPrediction(text_cause="empty", start_logit_cause=0.0, end_logit_cause=0.0,
485 | text_effect="empty", start_logit_effect=0.0, end_logit_effect=0.0,
486 | start_index_effect=0, end_index_effect=0,
487 | start_index_cause=0, end_index_cause=0))
488 |
489 | assert len(nbest) >= 1
490 |
491 | total_scores = []
492 | best_non_null_entry = None
493 | for entry in nbest:
494 | total_scores.append(entry.start_logit_cause + entry.end_logit_cause +
495 | entry.start_logit_effect + entry.end_logit_effect)
496 | if not best_non_null_entry:
497 | if entry.text_cause and entry.text_effect:
498 | best_non_null_entry = entry
499 |
500 | probabilities = _compute_softmax(total_scores)
501 |
502 | nbest_json = []
503 | current_example_spans = []
504 | for (i, entry) in enumerate(nbest):
505 | output = collections.OrderedDict()
506 | output["text"] = example.context_text
507 | output["probability"] = probabilities[i]
508 | output["cause_text"] = entry.text_cause
509 | output["cause_start_index"] = entry.start_index_cause
510 | output["cause_end_index"] = entry.end_index_cause
511 | output["cause_start_score"] = entry.start_logit_cause
512 | output["cause_end_score"] = entry.end_logit_cause
513 | output["effect_text"] = entry.text_effect
514 | output["effect_start_score"] = entry.start_logit_effect
515 | output["effect_end_score"] = entry.end_logit_effect
516 | output["effect_start_index"] = entry.start_index_effect
517 | output["effect_end_index"] = entry.end_index_effect
518 | new_span = SpanCombination(start_cause=entry.start_index_cause, end_cause=entry.end_index_cause,
519 | start_effect=entry.start_index_effect, end_effect=entry.end_index_effect)
520 | output["is_new"] = all([new_span != other for other in current_example_spans])
521 | nbest_json.append(output)
522 | current_example_spans.append(new_span)
523 |
524 | assert len(nbest_json) >= 1
525 | if suffix_index > 0:
526 | suffix_index -= 1
527 | all_predictions[example.example_id] = {"text": nbest_json[suffix_index]["text"],
528 | "cause_text": nbest_json[suffix_index]["cause_text"],
529 | "effect_text": nbest_json[suffix_index]["effect_text"]}
530 | all_nbest_json[example.example_id] = nbest_json
531 |
532 | output_prediction_file = output_dir / "predictions.json"
533 | csv_output_prediction_file = output_dir / "predictions.csv"
534 | output_nbest_file = output_dir / "nbest_predictions.json"
535 |
536 | logger.info("Writing predictions to: %s" % output_prediction_file)
537 | with open(output_prediction_file, "w") as writer:
538 | writer.write(json.dumps(all_predictions, indent=4) + "\n")
539 |
540 | with open(csv_output_prediction_file, "w", encoding='utf-8', newline='') as writer:
541 | csv_writer = csv.writer(writer, delimiter=';')
542 | csv_writer.writerow(['Index', 'Text', 'Cause', 'Effect'])
543 | for (example_id, prediction) in all_predictions.items():
544 | csv_writer.writerow([example_id, prediction['text'], prediction['cause_text'], prediction['effect_text']])
545 |
546 | logger.info("Writing nbest to: %s" % output_nbest_file)
547 | with open(output_nbest_file, "w") as writer:
548 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
549 |
550 | return all_predictions
551 |
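# Example-id suffix handling in compute_predictions_logits, sketched with a
# hypothetical id: when run_config.top_n_sentences is set, an id such as
# '0001.00123.2' yields suffix_index = 2, which is decremented to pick
# nbest_json[1], i.e. the second-ranked span combination; ids with a single
# '.' keep suffix_index = 0 and fall back to the top prediction.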
552 |
553 | class SpanCombination:
554 | def __init__(self, start_cause: int, end_cause: int, start_effect: int, end_effect: int):
555 | self.start_cause = start_cause
556 | self.start_effect = start_effect
557 | self.end_cause = end_cause
558 | self.end_effect = end_effect
559 |
560 | def __eq__(self, other):
561 | overlapping_cause = (self.start_cause <= other.start_cause <= self.end_cause) or \
562 | (self.start_cause <= other.end_cause <= self.end_cause) or \
563 | (self.start_effect <= other.start_cause <= self.end_effect) or \
564 | (self.start_effect <= other.end_cause <= self.end_effect)
565 | overlapping_effect = (self.start_effect <= other.start_effect <= self.end_effect) or \
566 | (self.start_effect <= other.end_effect <= self.end_effect) or \
567 | (self.start_cause <= other.start_effect <= self.end_cause) or \
568 | (self.start_cause <= other.end_effect <= self.end_cause)
569 | return overlapping_cause and overlapping_effect
570 |
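# SpanCombination equality means "both the cause and the effect regions
# overlap", not strict identity; a small illustration with hypothetical
# indices:
#
#     a = SpanCombination(start_cause=0, end_cause=4, start_effect=6, end_effect=9)
#     b = SpanCombination(start_cause=2, end_cause=5, start_effect=7, end_effect=9)
#     assert a == b    # causes overlap on [2, 4], effects overlap on [7, 9]
#     c = SpanCombination(start_cause=10, end_cause=12, start_effect=6, end_effect=9)
#     assert a != c    # c's cause touches neither a's cause nor a's effect
#
# compute_predictions_logits relies on this to set the "is_new" flag: a
# prediction is new only if it overlaps none of the combinations seen so far.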
571 |
572 | def _get_best_indexes(logits, n_best_size) -> List[int]:
573 | """Get the n-best logits from a list."""
574 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
575 |
576 | best_indexes = []
577 | for i in range(len(index_and_score)):
578 | if i >= n_best_size:
579 | break
580 | best_indexes.append(index_and_score[i][0])
581 | return best_indexes
582 |
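# _get_best_indexes keeps the positions of the n highest logits, best first
# (a doctest-style sketch):
#
#     >>> _get_best_indexes([0.1, 3.2, 1.5, 2.7], n_best_size=2)
#     [1, 3]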
583 |
584 | def _compute_softmax(scores) -> List[float]:
585 | """Compute softmax probability over raw logits."""
586 | if not scores:
587 | return []
588 |
589 | max_score = None
590 | for score in scores:
591 | if max_score is None or score > max_score:
592 | max_score = score
593 |
594 | exp_scores = []
595 | total_sum = 0.0
596 | for score in scores:
597 | x = math.exp(score - max_score)
598 | exp_scores.append(x)
599 | total_sum += x
600 |
601 | probabilities = []
602 | for score in exp_scores:
603 | probabilities.append(score / total_sum)
604 | return probabilities
605 |
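# _compute_softmax is the standard max-shifted softmax over raw logits; equal
# scores give a uniform distribution (doctest-style sketch):
#
#     >>> _compute_softmax([1.0, 1.0])
#     [0.5, 0.5]
#     >>> _compute_softmax([])
#     []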
--------------------------------------------------------------------------------