├── .gitignore ├── LICENSE.txt ├── README.md ├── data ├── instrdt │ ├── test_ids.txt │ ├── tools │ │ ├── extract.py │ │ ├── preprocess.py │ │ └── utils.py │ ├── train_ids.txt │ └── valid_ids.txt └── rstdt │ ├── .gitkeep │ └── sample.json ├── models └── .gitkeep ├── requirements.txt ├── scripts ├── generator_rstdt.sh └── run_shift_reduce_v1.deberta-base.1e-5.sh └── src ├── average_ckpt.py ├── data ├── batch.py ├── datamodule.py ├── dataset.py ├── doc.py ├── edu.py ├── rstdt_relation.py ├── tree.py └── utils.py ├── metrics ├── __init__.py ├── original_parseval.py ├── parseval.py └── rst_parseval.py ├── models ├── classifier │ ├── __init__.py │ ├── classifier_base.py │ ├── linear.py │ ├── shift_reduce_classifier_base.py │ ├── shift_reduce_classifier_v1.py │ ├── shift_reduce_classifier_v2.py │ ├── shift_reduce_classifier_v3.py │ ├── top_down_classifier_base.py │ ├── top_down_classifier_v1.py │ └── top_down_classifier_v2.py ├── encoder │ ├── __init__.py │ ├── bert_encoder.py │ └── encoder.py └── parser │ ├── __init__.py │ ├── organization_feature.py │ ├── parser_base.py │ ├── shift_reduce_parser_base.py │ ├── shift_reduce_parser_v1.py │ ├── shift_reduce_parser_v2.py │ ├── shift_reduce_parser_v3.py │ ├── shift_reduce_state.py │ ├── top_down_parser_base.py │ ├── top_down_parser_v1.py │ ├── top_down_parser_v2.py │ └── utils.py ├── parse.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | SOFTWARE LICENSE AGREEMENT FOR EVALUATION 2 | 3 | This SOFTWARE EVALUATION LICENSE AGREEMENT (this "Agreement") is a legal contract between a person who uses or otherwise accesses or installs the Software ("User(s)"), and Nippon Telegraph and Telephone corporation ("NTT"). 4 | READ THE TERMS AND CONDITIONS OF THIS AGREEMENT CAREFULLY BEFORE INSTALLING OR OTHERWISE ACCESSING OR USING NTT'S PROPRIETARY SOFTWARE ACCOMPANIED BY THIS AGREEMENT (the "SOFTWARE"). THE SOFTWARE IS COPYRIGHTED AND IT IS LICENSED TO USER UNDER THIS AGREEMENT, NOT SOLD TO USER. BY INSTALLING OR OTHERWISE ACCESSING OR USING THE SOFTWARE, USER ACKNOWLEDGES THAT USER HAS READ THIS AGREEMENT, THAT USER UNDERSTANDS IT, AND THAT USER ACCEPTS AND AGREES TO BE BOUND BY ITS TERMS. IF AT ANY TIME USER IS NOT WILLING TO BE BOUND BY THE TERMS OF THIS AGREEMENT, USER SHOULD TERMINATE THE INSTALLATION PROCESS, IMMEDIATELY CEASE AND REFRAIN FROM ACCESSING OR USING THE SOFTWARE AND DELETE ANY COPIES USER MAY HAVE. THIS AGREEMENT REPRESENTS THE ENTIRE AGREEMENT BETWEEN USER AND NTT CONCERNING THE SOFTWARE. 5 | 6 | 7 | BACKGROUND 8 | A. NTT is the owner of all rights, including all patent rights, copyrights and trade secret rights, in and to the Software and related documentation listed in Exhibit A to this Agreement. 9 | B. User wishes to obtain a royalty free license to use the Software to enable User to evaluate, and NTT wishes to grant such a license to User, pursuant and subject to the terms and conditions of this Agreement. 10 | C. As a condition to NTT's provision of the Software to User, NTT has required User to execute this Agreement. 11 | In consideration of these premises, and the mutual promises and conditions in this Agreement, the parties hereby agree as follows: 12 | 1. Grant of Evaluation License. NTT hereby grants to User, and User hereby accepts, under the terms and conditions of this Agreement, a royalty free, nontransferable and nonexclusive license to use the Software internally for the purposes of testing, analyzing, and evaluating the methods or mechanisms as shown in the research paper submitted by NTT to a certain academy. User may make a reasonable number of backup copies of the Software solely for User's internal use pursuant to the license granted in this Section 1. 13 | 2. Shipment and Installation. NTT will ship or deliver the Software by any method that NTT deems appropriate. User shall be solely responsible for proper installation of the Software. 14 | 3. Term. 
This Agreement is effective whichever is earlier (i) upon User's acceptance of the Agreement, or (ii) upon User's installing, accessing, and using the Software, even if User has not expressly accepted this Agreement. Without prejudice to any other rights, NTT may terminate this Agreement without notice to User (i) if User breaches or fails to comply with any of the limitations or other requirements described herein, and (ii) if NTT receives a notice from the academy stating that the research paper would not be published, and in any such case User agrees that NTT may, in addition to any other remedies it may have at law or in equity, remotely disable the Software. User may terminate this Agreement at any time by Users decision to terminate the Agreement to NTT and ceasing use of the Software. Upon any termination or expiration of this Agreement for any reason, User agrees to uninstall the Software and either return to NTT the Software and all copies thereof, or to destroy all such materials and provide written verification of such destruction to NTT. 15 | 4. Proprietary Rights 16 | (a) The Software is the valuable, confidential, and proprietary property of NTT, and NTT shall retain exclusive title to this property both during the term and after the termination of this Agreement. Without limitation, User acknowledges that all patent rights, copyrights and trade secret rights in the Software shall remain the exclusive property of NTT at all times. User shall use not less than reasonable care in safeguarding the confidentiality of the Software. 17 | (b) USER SHALL NOT, IN WHOLE OR IN PART, AT ANY TIME DURING THE TERM OF OR AFTER THE TERMINATION OF THIS AGREEMENT: (i)?SELL, ASSIGN, LEASE, DISTRIBUTE, OR OTHERWISE TRANSFER THE SOFTWARE TO ANY THIRD PARTY; (ii) EXCEPT AS OTHERWISE PROVIDED HEREIN, COPY OR REPRODUCE THE SOFTWARE IN ANY MANNER; (iii) DISCLOSE THE SOFTWARE TO ANY THIRD PARTY, EXCEPT TO USER'S EMPLOYEES WHO REQUIRE ACCESS TO THE SOFTWARE FOR THE PURPOSES OF THIS AGREEMENT; (iv) MODIFY, DISASSEMBLE, DECOMPILE, REVERSE ENGINEER OR TRANSLATE THE SOFTWARE; OR (v) ALLOW ANY PERSON OR ENTITY TO COMMIT ANY OF THE ACTIONS DESCRIBED IN (i) THROUGH (iv) ABOVE. 18 | (c) User shall take appropriate action, by instruction, agreement, or otherwise, with respect to its employees permitted under this Agreement to have access to the Software to ensure that all of User's obligations under this Section 4 shall be satisfied. 19 | 5. Indemnity. User shall defend, indemnify and hold harmless NTT, its agents and employees, from any loss, damage, or liability arising in connection with User's improper or unauthorized use of the Software. NTT SHALL HAVE THE SOLE RIGHT TO CONDUCT DEFEND ANY ACTTION RELATING TO THE SOFTWARE. 20 | 6. Disclaimer. THE SOFTWARE IS LICENSED TO USER "AS IS," WITHOUT ANY TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF NTT. NTT MAKES NO EXPRESS OR IMPLIED WARRANTIES OF ANY TYPE WHATSOEVER, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY, OF FITNESS FOR A PARTICULAR PURPOSE AND OF NON-INFRINGEMENT ON COPYRIGHT OR ANY OTHER RIGHT OF THIRD PARTIES. USER ASSUMES ALL RISKS ASSOCIATED WITH ITS USE OF THE SOFTWARE, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, PERFORMANCE, DATA LOSS, AND UTILITY IN A PRODUCTION ENVIRONMENT. 21 | 7. Limitation of Liability. 
IN NO EVENT SHALL NTT BE LIABLE TO USER OR TO ANY THIRD PARTY FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING BUT NOT LIMITED TO DAMAGES FOR PERSONAL INJURY, PROPERTY DAMAGE, LOST PROFITS, OR OTHER ECONOMIC LOSS, ARISING IN CONNECTION WITH USER'S USE OF OR INABILITY TO USE THE SOFTWARE, IN CONNECTION WITH NTT'S PROVISION OF OR FAILURE TO PROVIDE SERVICES PERTAINING TO THE SOFTWARE, OR AS A RESULT OF ANY DEFECT IN THE SOFTWARE. THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARDLESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST NTT, WHETHER IN CONTRACT OR TORT, INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE. USER'S SOLE REMEDY IN THE EVENT OF ANY BREACH OF THIS AGREEMENT BY NTT SHALL BE TERMINATION PURSUANT TO SECTION 3. 22 | 8. No Assignment or Sublicense. Neither this Agreement nor any right or license under this Agreement, nor the Software, may be sublicensed, assigned, or otherwise transferred by User without NTT's prior written consent. 23 | 9. General 24 | (a) If any provision, or part of a provision, of this Agreement is or becomes illegal, unenforceable, or invalidated, by operation of law or otherwise, that provision or part shall to that extent be deemed omitted, and the remainder of this Agreement shall remain in full force and effect. 25 | (b) This Agreement is the complete and exclusive statement of the agreement between the parties with respect to the subject matter hereof, and supersedes all written and oral contracts, proposals, and other communications between the parties relating to that subject matter. 26 | (c) Subject to Section 8, this Agreement shall be binding on, and shall inure to the benefit of, the respective successors and assigns of NTT and User. 27 | (d) If either party to this Agreement initiates a legal action or proceeding to enforce or interpret any part of this Agreement, the prevailing party in such action shall be entitled to recover, as an element of the costs of such action and not as damages, its attorneys' fees and other costs associated with such action or proceeding. 28 | (e) This Agreement shall be governed by and interpreted under the laws of Japan, without reference to conflicts of law principles. All disputes arising out of or in connection with this Agreement shall be finally settled by arbitration in Tokyo in accordance with the Commercial Arbitration Rules of the Japan Commercial Arbitration Association. The arbitration shall be conducted by three (3) arbitrators and in Japanese. The award rendered by the arbitrators shall be final and binding upon the parties. Judgment upon the award may be entered in any court having jurisdiction thereof. 29 | (f) NTT shall not be liable to the User or to any third party for any delay or failure to perform NTT's obligation set forth under this Agreement due to any cause beyond NTTs reasonable control. 30 | 31 | 32 | EXHIBIT A 33 | The software and related documentation in this repository. 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Implemenation of Neural RST Parser 2 | 3 | **A Simple and Strong Baseline for End-to-End Neural RST-style Discourse Parsing** 4 | 5 | Paper: [URL for arXiv](https://arxiv.org/abs/2210.08355) 6 | 7 | Abstract: 8 | To promote and further develop RST-style discourse parsing models, 9 | we need a strong baseline that can be regarded as a reference for reporting reliable experimental results. 
10 | This paper explores a strong baseline by integrating existing simple parsing strategies, 11 | top-down and bottom-up, with various transformer-based pre-trained language models. 12 | The experimental results obtained from two benchmark datasets demonstrate that 13 | the parsing performance strongly relies on the pre-trained language models rather than the parsing strategies. 14 | In particular, the bottom-up parser achieves large performance gains compared to the current best parser when employing DeBERTa. 15 | We further reveal that language models with a span-masking scheme especially boost 16 | the parsing performance through our analysis within intra- and multi-sentential parsing, and nuclearity prediction. 17 | 18 | 19 | ## Setup 20 | 1. prepare a Python environment with conda 21 | 2. clone this repository 22 | ```bash 23 | git clone https://github.com/nttcslab-nlp/RSTParser_EMNLP22 24 | cd RSTParser_EMNLP22 25 | ``` 26 | 3. manually install PyTorch to enable GPU support 27 | ```bash 28 | conda install pytorch cudatoolkit=XXX -c pytorch 29 | conda install torchtext -c pytorch 30 | ``` 31 | 4. install the dependencies 32 | ```bash 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | ## Preprocess for dataset 37 | ### RSTDT 38 | If you have the RSTDT dataset that has been preprocessed 39 | by [Heilman's script](https://github.com/EducationalTestingService/rstfinder.git), 40 | you can use it as is. 41 | 42 | Each entry in Heilman's format has the following elements: 43 | ``` 44 | - doc_id 45 | - edu_start_indices 46 | - edu_starts_paragraph 47 | - edu_strings 48 | - path_basename # not necessary 49 | - pos_tags # not necessary 50 | - rst_tree 51 | - syntax_trees # not necessary 52 | - token_tree_positions # not necessary 53 | - tokens 54 | ``` 55 | 56 | There is a sample at `data/rstdt/sample.json`. 57 | 58 | 59 | ### Instr-DT 60 | The Instr-DT dataset contains documents that are annotated with multiple trees. 61 | We combine the multiple trees of such a document into a single tree 62 | with the "Nucleus-Nucleus" and "topic-change?" labels. 63 | ```bash 64 | cd data/instrdt 65 | git clone https://github.com/EducationalTestingService/rstfinder.git 66 | rm rstfinder/rstfinder/__init__.py # removed because it requires the zpar package 67 | ln -nfs $PWD/rstfinder/rstfinder/ tools/ 68 | 69 | python tools/preprocess.py \ 70 | --input-dir PATH/TO/instr-discourse-data/discourse_annotation/ \ 71 | --output-file all.json \ 72 | --joint-with-nn 73 | for tgt in train valid test; do 74 | python tools/extract.py \ 75 | --src all.json \ 76 | --tgt ${tgt}.json \ 77 | --ids ${tgt}_ids.txt 78 | done 79 | ``` 80 | 81 | 82 | ## Train and test 83 | ```bash 84 | bash scripts/run_shift_reduce_v1.deberta-base.1e-5.sh 85 | ``` 86 | This script was generated by `scripts/generator_rstdt.sh`. 87 | Some environment variables (`CUDA_VISIBLE_DEVICES`) are hard-coded in the script. 88 | 89 | Although the maximum is 20 epochs, training converges in about 5 epochs. 90 | (Training time is about 1h/epoch on a GeForce RTX 3090.) 91 | 92 | Models are saved into `./models/rstdt/shift_reduce_v1.deberta-base.1e-5/version_?/checkpoints/`.
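With the default scripts, `version_0`, `version_1`, and `version_2` correspond to the three seeds in the training loop. If you want to check what a finished run actually produced, a small sketch like the following (a hypothetical helper, not part of this repository; the path follows the script defaults above) lists the checkpoint directories and their contents:

```python
# Hypothetical helper: list the checkpoint directories created by the runs above.
from pathlib import Path

model_dir = Path("models/rstdt/shift_reduce_v1.deberta-base.1e-5")
for ckpt_dir in sorted(model_dir.glob("version_*/checkpoints")):
    ckpt_names = sorted(p.name for p in ckpt_dir.glob("*.ckpt"))
    print(ckpt_dir, ckpt_names)
```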
93 | 94 | Saved models are followings 95 | ``` 96 | PATH/TO/checkpoints/ 97 | - epoch=3-step=?????.ckpt # saved in training process 98 | - epoch=3-step=?????.ckpt # saved in training process 99 | - epoch=4-step=?????.ckpt # saved in training process 100 | - last.ckpt # saved at the end of training process 101 | - best.ctpt # selected the best model by validation score in evaluation process 102 | - average.ckpt # output of checkpoint weight averaging (CPA) at evaluation process 103 | ``` 104 | 105 | 106 | ## Test only single checkpoint 107 | ```bash 108 | python src/test.py --ckpt-path PATH/TO/CKPT --save-dir PATH/TO/TREES/ --metrics OriginalParseval 109 | ``` 110 | 111 | 112 | ## Parse raw document (sequence of EDUs) 113 | TBU 114 | 115 | 116 | ## Performance and Checkpoints 117 | These are results evaluated by OriginalParseval. 118 | (scores are a bit different from the paper due to model retraining.) 119 | 120 | |Model |LM |CPA|Span |Nuc. |Rel. |Ful. |ckpt| 121 | |------------|-------|:-:|----:|----:|----:|----:|----| 122 | |Shift-Reduce|DeBERTa| | 77.0| 67.2| 57.5| 55.5| [Google Drive](https://drive.google.com/file/d/14CPKfvaGolg5Kd0smBpb2B4un3ZBaJgQ/view?usp=share_link)| 123 | |Shift-Reduce|DeBERTa| x | 78.3| 69.0| 58.2| 56.2| [Google Drive](https://drive.google.com/file/d/1Xt63FI-VfovyjtobM6K5oTBsDq_14SI-/view?usp=share_link)| 124 | 125 | 126 | ## Reference 127 | ```text 128 | TBU 129 | ``` 130 | 131 | 132 | ## LICENSE 133 | 134 | This software is released under the NTT License, see `LICENSE.txt`. 135 | 136 | According to the license, it is not allowed to create pull requests. Please feel free to send issues. 137 | -------------------------------------------------------------------------------- /data/instrdt/test_ids.txt: -------------------------------------------------------------------------------- 1 | sunset1-paint.6.1 2 | sunset1-fabric.5.1 3 | sunset1-awp.11.1 4 | sunset1-fabric.21.1 5 | sunset1-paint.14.1 6 | sunset1-panel.9.1 7 | sunset2-5-parquet.12.1 8 | sunset1-prep.3.1 9 | sunset1-paint.8.1 10 | sunset1-fabric.12.1 11 | sunset1-awp.1.1 12 | sunset1-paint.2.1 13 | sunset1-paint.17.1 14 | sunset1-fabric.10.1 15 | sunset2-3-ceramic.37.1 16 | sunset2-3-ceramic.17.1 17 | installing-wood-tiles.tx.1 18 | sunset2-3-ceramic.6.1 19 | sunset2-3-ceramic.1.1 20 | sunset1-paint.10.1 21 | sunset2-3-ceramic.27.1 22 | sunset1-fabric.24.1 23 | sunset2-4-resil.8.1 24 | sunset1-awp.27.1 25 | sunset1-awp.4.1 26 | -------------------------------------------------------------------------------- /data/instrdt/tools/extract.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import json 4 | 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--src", 10 | type=Path, 11 | required=True, 12 | help="input json file", 13 | ) 14 | parser.add_argument( 15 | "--ids", 16 | type=Path, 17 | required=True, 18 | help="text file contains doc_ids", 19 | ) 20 | parser.add_argument( 21 | "--tgt", 22 | type=Path, 23 | required=True, 24 | help="output json file", 25 | ) 26 | config = parser.parse_args() 27 | 28 | doc_ids = None 29 | with open(config.ids) as f: 30 | doc_ids = {line.strip(): None for line in f} 31 | 32 | assert len(doc_ids) > 0 33 | 34 | doc_id2data = {} 35 | with open(config.src) as f: 36 | for data in json.load(f): 37 | doc_id = data["doc_id"] 38 | if doc_id in doc_ids: 39 | assert doc_id not in doc_id2data, "Duplicate doc_ids" 40 | doc_id2data[doc_id] = data 41 | 42 | dataset = 
[doc_id2data[doc_id] for doc_id in doc_ids] 43 | with open(config.tgt, "w") as f: 44 | json.dump(dataset, f) 45 | 46 | return 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /data/instrdt/tools/preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from nltk import ParentedTree, Tree 4 | import json 5 | 6 | from rstfinder.reformat_rst_trees import reformat_rst_tree 7 | from utils import ( 8 | extract_edus_from_rst_tree_str, 9 | remove_edu_from_rst_tree_str, 10 | divide_rst_tree_str, 11 | fix_relation_label, 12 | binarize, 13 | rst_to_attach, 14 | attach_to_rst, 15 | re_assign_edu_idx, 16 | TREE_PRINT_MARGIN, 17 | ) 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | "--input-dir", 24 | required=True, 25 | type=Path, 26 | help="path to original instr-dt dataset (contains .out files)", 27 | ) 28 | parser.add_argument( 29 | "--output-file", 30 | required=True, 31 | type=Path, 32 | help="path to output json file", 33 | ) 34 | parser.add_argument( 35 | "--joint-with-nn", 36 | action="store_true", 37 | help="joint multiple trees in one document to single tree", 38 | ) 39 | args = parser.parse_args() 40 | 41 | dataset = [] 42 | num_docs = 0 43 | for file_path in args.input_dir.glob("*.out"): 44 | data = read_data(file_path, args.joint_with_nn) 45 | dataset.extend(data) 46 | num_docs += 1 47 | 48 | # check size 49 | assert num_docs == 176 50 | if args.joint_with_nn: 51 | assert len(dataset) == 176 52 | else: 53 | assert len(dataset) == 320 54 | 55 | save_json(args.output_file, dataset) 56 | return 57 | 58 | 59 | def read_data(file_path, joint_with_nn=False): 60 | path_basename = file_path.name 61 | doc_id = file_path.name.rstrip(".out") 62 | edu_file = file_path.with_suffix(".out.edus") 63 | dis_file = file_path.with_suffix(".out.dis") 64 | 65 | edu_strings = read_edus(edu_file) 66 | rst_trees, _edu_strings = read_dis(dis_file, joint_with_nn) 67 | assert len(edu_strings) == len(_edu_strings) 68 | 69 | edu_starts_sentence = [True if edu.startswith("") else False for edu in _edu_strings] 70 | edu_starts_paragraph = [False] * len(edu_strings) # no annotated 71 | 72 | data = [] 73 | edu_offset = 0 74 | for idx, rst_tree in enumerate(rst_trees): 75 | bi_rst_tree = binarize(rst_tree) 76 | attach_tree = rst_to_attach(bi_rst_tree) 77 | _bi_rst_tree = attach_to_rst(attach_tree) 78 | assert _bi_rst_tree == bi_rst_tree 79 | n_edus = len(rst_tree.leaves()) 80 | assert n_edus == len(edu_strings[edu_offset : edu_offset + n_edus]) 81 | data.append( 82 | { 83 | "path_basename": path_basename, 84 | "doc_id": doc_id + ".{}".format(idx + 1), 85 | "rst_tree": rst_tree.pformat(margin=TREE_PRINT_MARGIN), 86 | "binarised_rst_tree": bi_rst_tree.pformat(margin=TREE_PRINT_MARGIN), 87 | "attach_tree": attach_tree.pformat(margin=TREE_PRINT_MARGIN), 88 | "edu_strings": edu_strings[edu_offset : edu_offset + n_edus], 89 | "edu_starts_sentence": edu_starts_sentence[edu_offset : edu_offset + n_edus], 90 | "edu_starts_paragraph": edu_starts_paragraph[edu_offset : edu_offset + n_edus], 91 | } 92 | ) 93 | edu_offset = edu_offset + n_edus 94 | 95 | assert len(edu_strings) == edu_offset 96 | 97 | return data 98 | 99 | 100 | def read_edus(file_path): 101 | edus = [] 102 | with open(file_path) as f: 103 | for line in f: 104 | edu = line.strip() 105 | edus.append(edu) 106 | 107 | return edus 108 | 109 | 110 | 
def read_dis(file_path, joint_with_NN=False): 111 | trees = [] 112 | with open(file_path) as f: 113 | rst_tree_str = f.read().strip() 114 | edus = extract_edus_from_rst_tree_str(rst_tree_str) 115 | rst_tree_str = remove_edu_from_rst_tree_str(rst_tree_str) 116 | 117 | for _rst_tree_str in divide_rst_tree_str(rst_tree_str): 118 | rst_tree = ParentedTree.fromstring(_rst_tree_str) 119 | reformat_rst_tree(rst_tree) 120 | rst_tree = Tree.convert(rst_tree) 121 | trees.append(rst_tree) 122 | 123 | if joint_with_NN and len(trees) != 1: 124 | for tree in trees: 125 | tree.set_label("nucleus:topic-change?") 126 | 127 | tree = Tree("ROOT", trees) 128 | re_assign_edu_idx(tree) 129 | trees = [tree] 130 | 131 | for tree in trees: 132 | fix_relation_label(tree) 133 | 134 | return trees, edus 135 | 136 | 137 | def save_json(file_path, dataset): 138 | print('save into "{}" (consists of {} trees)'.format(file_path, len(dataset))) 139 | with open(file_path, "w") as f: 140 | json.dump(dataset, f) 141 | 142 | return 143 | 144 | 145 | if __name__ == "__main__": 146 | main() 147 | -------------------------------------------------------------------------------- /data/instrdt/tools/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from nltk import Tree 3 | 4 | TREE_PRINT_MARGIN = 1000000000 5 | 6 | 7 | def divide_rst_tree_str(rst_tree_str: str): 8 | depth = 0 9 | start, end = 0, None 10 | for i, c in enumerate(rst_tree_str): 11 | if c == "\n": 12 | continue 13 | 14 | if c == "(": 15 | depth += 1 16 | elif c == ")": 17 | depth -= 1 18 | 19 | assert depth >= 0 20 | 21 | if depth == 0: 22 | end = i + 1 23 | yield rst_tree_str[start:end] 24 | start = end 25 | 26 | 27 | def extract_edus_from_rst_tree_str(rst_tree_str: str): 28 | edus = re.findall(r"\(text (.*?)\)", rst_tree_str) 29 | return edus 30 | 31 | 32 | def remove_edu_from_rst_tree_str(rst_tree_str: str): 33 | rst_tree_str = re.sub(r"(.*?)", "EDU_TEXT", rst_tree_str) 34 | return rst_tree_str 35 | 36 | 37 | def fix_relation_label(tree: Tree): 38 | for tp in tree.treepositions(): 39 | node = tree[tp] 40 | if not isinstance(node, Tree): 41 | continue 42 | if node.label() == "text": 43 | continue 44 | 45 | childs = [c for c in node] 46 | if len(childs) == 1 and childs[0].label() == "text": 47 | continue 48 | 49 | rel_labels = [c.label().split(":", maxsplit=1)[1] for c in childs] 50 | nuc_labels = [c.label().split(":", maxsplit=1)[0] for c in childs] 51 | assert all([l == rel_labels[0] for l in rel_labels]) 52 | 53 | if all([n == "nucleus" for n in nuc_labels]): 54 | # N-N 55 | continue 56 | else: 57 | for n, child in zip(nuc_labels, childs): 58 | if n == "nucleus": 59 | child.set_label("nucleus:span") 60 | return 61 | 62 | 63 | def binarize(tree: Tree, top: bool = True): 64 | if top: 65 | assert tree.label() == "ROOT" 66 | if len(tree) == 1: 67 | # End of recursion 68 | return tree 69 | elif len(tree) == 2: 70 | # Binary structure 71 | left_tree = binarize(tree[0], top=False) 72 | right_tree = binarize(tree[1], top=False) 73 | else: 74 | # Non-Binary structure 75 | labels = [tree[i].label() for i in range(len(tree))] 76 | is_polynuclear = all(map(lambda x: x == labels[0], labels)) 77 | if is_polynuclear: 78 | # Polynuclear relation label such as: 79 | # same-unit, list, etc... 
80 | # -> convert to right heavy structure 81 | left_tree = binarize(tree[0], top=False) 82 | right_tree = binarize( 83 | Tree(tree[0].label(), [tree[i] for i in range(1, len(tree))]), top=False 84 | ) 85 | else: 86 | # Non Binary structure without Polynuclear label 87 | # S/N/S -> left heavy 88 | left_tree = binarize(Tree("nucleus:span", [tree[0], tree[1]]), top=False) 89 | right_tree = binarize(tree[2], top=False) 90 | 91 | return Tree(tree.label(), [left_tree, right_tree]) 92 | 93 | 94 | def is_binary(tree: Tree): 95 | if not isinstance(tree, Tree): 96 | return True 97 | elif len(tree) > 2: 98 | return False 99 | else: 100 | return all([is_binary(child) for child in tree]) 101 | 102 | 103 | def attach_to_rst(tree: Tree, label: str = "ROOT"): 104 | if len(tree) == 1: 105 | return Tree(label, [tree]) 106 | nuc_label, rel_label = tree.label().split(":", maxsplit=1) 107 | if len(nuc_label.split("-")) == 1: 108 | raise ValueError("Invalid nucleus label: {}".format(nuc_label)) 109 | 110 | l_nuc, r_nuc = nuc_label.split("-") 111 | if nuc_label == "nucleus-satellite": 112 | l_rel = "span" 113 | r_rel = rel_label 114 | elif nuc_label == "satellite-nucleus": 115 | l_rel = rel_label 116 | r_rel = "span" 117 | elif nuc_label == "nucleus-nucleus": 118 | l_rel = r_rel = rel_label 119 | else: 120 | raise ValueError("Unkwon Nucleus label: {}".format(nuc_label)) 121 | 122 | return Tree( 123 | label, 124 | [ 125 | attach_to_rst(tree[0], ":".join([l_nuc, l_rel])), 126 | attach_to_rst(tree[1], ":".join([r_nuc, r_rel])), 127 | ], 128 | ) 129 | 130 | 131 | def rst_to_attach(rst_tree: Tree): 132 | if len(rst_tree) == 1: 133 | return rst_tree[0] 134 | 135 | l_nuc, l_rel = rst_tree[0].label().split(":", maxsplit=1) 136 | r_nuc, r_rel = rst_tree[1].label().split(":", maxsplit=1) 137 | nuc = "-".join([l_nuc, r_nuc]) 138 | rel = l_rel if l_nuc == "satellite" else r_rel 139 | label = ":".join([nuc, rel]) 140 | return Tree(label, [rst_to_attach(child) for child in rst_tree]) 141 | 142 | 143 | def re_assign_edu_idx(tree: Tree): 144 | for idx, tp in enumerate(tree.treepositions("leave")): 145 | lp = tp[:-1] 146 | tree[lp] = Tree("text", [str(idx)]) 147 | 148 | 149 | def is_nucleus(node: Tree): 150 | label = node.label() 151 | if label in ["text", "ROOT"]: 152 | return False 153 | 154 | nuc = label.split(":")[0] 155 | assert nuc in ["nucleus", "satellite"] 156 | return nuc == "nucleus" 157 | 158 | 159 | def is_satellite(node: Tree): 160 | return not is_nucleus(node) 161 | 162 | 163 | def is_multi_nucleus(node: Tree): 164 | if isinstance(node, str): 165 | return False 166 | label = node.label() 167 | if label in ["text"]: 168 | return False 169 | if len(node) <= 2: 170 | return False 171 | 172 | return True 173 | -------------------------------------------------------------------------------- /data/instrdt/train_ids.txt: -------------------------------------------------------------------------------- 1 | sunset2-3-ceramic.9.1 2 | sunset1-paint.27.1 3 | sunset2-3-ceramic.36.1 4 | sunset1-awp.21.1 5 | sunset2-3-ceramic.22.1 6 | sunset2-4-resil.10.1 7 | sunset1-paint.20.1 8 | sunset2-5-parquet.13.1 9 | sunset1-fabric.22.1 10 | sunset1-paint.7.1 11 | sunset2-3-ceramic.26.1 12 | sunset1-awp.6.1 13 | sunset1-panel.1.1 14 | sunset2-3-ceramic.31.1 15 | sunset1-paint.4.1 16 | sunset1-awp.23.1 17 | sunset1-fabric.25.1 18 | sunset1-awp.2.1 19 | sunset1-awp.26.1 20 | sunset2-3-ceramic.32.1 21 | sunset2-3-ceramic.40.1 22 | sunset2-3-ceramic.11.1 23 | sunset2-3-ceramic.20.1 24 | sunset2-5-parquet.10.1 25 | sunset1-awp.14.1 26 | 
sunset2-3-ceramic.21.1 27 | sunset1-awp.19.1 28 | sunset2-3-ceramic.23.1 29 | sunset2-3-ceramic.13.1 30 | sunset2-3-ceramic.5.1 31 | sunset2-3-ceramic.38.1 32 | sunset1-paint.18.1 33 | sunset1-fabric.2.1 34 | sunset1-paint.22.1 35 | sunset1-panel.2.1 36 | sunset1-fabric.11.1 37 | sunset1-awp.13.1 38 | sunset1-fabric.1.1 39 | sunset2-3-ceramic.18.1 40 | sunset2-3-ceramic.14.1 41 | sunset2-4-resil.6.1 42 | sunset2-4-resil.7.1 43 | sunset1-awp.15.1 44 | sunset2-5-parquet.7.1 45 | sunset1-panel.15.1 46 | sunset2-5-parquet.6.1 47 | sunset1-awp.3.1 48 | sunset2-5-parquet.3.1 49 | sunset1-fabric.14.1 50 | sunset1-awp.16.1 51 | sunset2-3-ceramic.3.1 52 | sunset1-fabric.19.1 53 | sunset2-3-ceramic.19.1 54 | sunset2-3-ceramic.2.1 55 | sunset2-5-parquet.8.1 56 | sunset1-paint.16.1 57 | sunset1-prep.7.1 58 | sunset2-3-ceramic.16.1 59 | sunset2-3-ceramic.8.1 60 | sunset1-panel.8.1 61 | sunset1-prep.8.1 62 | sunset2-3-ceramic.43.1 63 | sunset1-paint.11.1 64 | sunset1-panel.14.1 65 | sunset2-3-ceramic.39.1 66 | sunset2-5-parquet.11.1 67 | sunset2-3-ceramic.25.1 68 | sunset1-fabric.6.1 69 | sunset2-4-resil.12.1 70 | sunset1-paint.19.1 71 | sunset1-panel.10.1 72 | sunset1-fabric.4.1 73 | sunset2-5-parquet.5.1 74 | sunset1-awp.5.1 75 | sunset1-fabric.17.1 76 | sunset1-paint.1.1 77 | sunset2-5-parquet.4.1 78 | sunset1-fabric.13.1 79 | sunset1-awp.10.1 80 | sunset2-3-ceramic.15.1 81 | sunset1-prep.6.1 82 | sunset1-prep.5.1 83 | sunset1-paint.5.1 84 | sunset1-prep.2.1 85 | sunset2-3-ceramic.10.1 86 | sunset2-5-parquet.2.1 87 | sunset1-awp.25.1 88 | sunset2-3-ceramic.29.1 89 | sunset2-3-ceramic.4.1 90 | sunset2-4-resil.2.1 91 | sunset2-3-ceramic.28.1 92 | sunset1-awp.12.1 93 | sunset2-4-resil.1.1 94 | sunset1-paint.9.1 95 | sunset2-5-parquet.1.1 96 | sunset1-panel.3.1 97 | sunset1-awp.22.1 98 | sunset2-4-resil.9.1 99 | sunset1-awp.18.1 100 | sunset1-paint.23.1 101 | sunset2-4-resil.5.1 102 | sunset2-3-ceramic.30.1 103 | sunset2-3-ceramic.24.1 104 | sunset1-fabric.3.1 105 | sunset1-panel.16.1 106 | sunset1-prep.10.1 107 | sunset1-fabric.8.1 108 | sunset1-paint.21.1 109 | sunset2-4-resil.3.1 110 | sunset1-fabric.7.1 111 | sunset1-panel.7.1 112 | sunset1-prep.1.1 113 | sunset2-3-ceramic.12.1 114 | sunset1-awp.24.1 115 | sunset1-prep.9.1 116 | sunset1-fabric.16.1 117 | sunset1-paint.25.1 118 | sunset1-panel.12.1 119 | sunset2-3-ceramic.33.1 120 | sunset1-paint.26.1 121 | sunset2-4-resil.4.1 122 | sunset1-fabric.23.1 123 | sunset1-panel.4.1 124 | sunset1-fabric.20.1 125 | sunset1-panel.13.1 126 | sunset1-panel.6.1 127 | -------------------------------------------------------------------------------- /data/instrdt/valid_ids.txt: -------------------------------------------------------------------------------- 1 | sunset2-4-resil.11.1 2 | sunset2-3-ceramic.35.1 3 | sunset1-awp.17.1 4 | sunset2-3-ceramic.42.1 5 | sunset2-5-parquet.9.1 6 | sunset1-fabric.15.1 7 | sunset1-panel.5.1 8 | sunset1-fabric.18.1 9 | sunset2-3-ceramic.34.1 10 | sunset1-awp.7.1 11 | sunset1-awp.20.1 12 | sunset1-awp.28.1 13 | sunset1-panel.11.1 14 | sunset1-paint.24.1 15 | sunset2-3-ceramic.41.1 16 | sunset1-paint.3.1 17 | sunset1-fabric.9.1 18 | sunset1-fabric.26.1 19 | sunset1-awp.8.1 20 | sunset1-paint.13.1 21 | sunset2-3-ceramic.7.1 22 | sunset1-prep.4.1 23 | sunset1-awp.9.1 24 | sunset1-paint.12.1 25 | sunset1-paint.15.1 26 | -------------------------------------------------------------------------------- /data/rstdt/.gitkeep: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nttcslab-nlp/RSTParser_EMNLP22/391e17067b7930f550f71482fd940c6c551e7c19/data/rstdt/.gitkeep -------------------------------------------------------------------------------- /data/rstdt/sample.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "doc_id": "wsj_xxxx", 4 | "path_basename": "wsj_xxxx.out.edus", 5 | "tokens":[ 6 | ["John", "Smith", ",", "director", "of", "the", "ABC", "company", ",", "was", "named", "CEO", "of", "XYZ", "company", "."], 7 | ["John", "succeeds", "Thomas", "Brown", ",", "who", "is", "retiring", "as", "CEO", ",", "but", "will", "continue", "as", "a", "director", "of", "the", "ABC", "company", "."] 8 | ], 9 | "edu_strings":[ 10 | "John Smith, director of the ABC company, was named CEO of XYZ company.", 11 | "John succeeds Thomas Brown,", 12 | "who is retiring as CEO,", 13 | "but will continue as a director of the ABC company." 14 | ], 15 | "edu_start_indices":[ 16 | [0, 0, 0], 17 | [1, 0, 1], 18 | [1, 5, 2], 19 | [1, 11, 3] 20 | ], 21 | "rst_tree": "(ROOT (nucleus:span (text 0)) (satellite:circumstance (nucleus:span (text 1)) (satellite:elaboration-additional-e (satellite:concession (text 2)) (nucleus:span (text 3)))))", 22 | "edu_starts_paragraph": [true, false, false, false] 23 | } 24 | ] 25 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttcslab-nlp/RSTParser_EMNLP22/391e17067b7930f550f71482fd940c6c551e7c19/models/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk 2 | spacy>=3.0.0,<4.0.0 3 | en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.0/en_core_web_sm-3.4.0-py3-none-any.whl 4 | torchmetrics==0.8.2 5 | pytorch-lightning == 1.6.3 6 | transformers >= 4.19.2 7 | -------------------------------------------------------------------------------- /scripts/generator_rstdt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data_dir=./data/rstdt 4 | save_dir=./models/rstdt 5 | num_gpus=1 6 | cuda_devices=0 7 | 8 | for parser in top_down_v1 top_down_v2 shift_reduce_v1 shift_reduce_v2 shift_reduce_v3; do 9 | for lm in bert-base-cased roberta-base xlnet-base-cased spanbert-base-cased electra-base-discriminator mpnet-base deberta-base; do 10 | for lr in 1e-5 2e-4; do 11 | file_path=run_$parser.$lm.$lr.sh 12 | cat << EOF > $file_path 13 | #!/bin/bash 14 | set -x 15 | 16 | # OPTIONS 17 | DATA_DIR=$data_dir 18 | SAVE_DIR=$save_dir 19 | PARSER_TYPE=$parser 20 | BERT_TYPE=$lm 21 | LR=$lr 22 | NUM_GPUS=$num_gpus 23 | export CUDA_VISIBLE_DEVICES=$cuda_devices 24 | EOF 25 | 26 | cat << 'EOF' >> $file_path 27 | for SEED in 0 1 2; do 28 | MODEL_NAME=$PARSER_TYPE.$BERT_TYPE.$LR 29 | VERSION=$SEED 30 | 31 | if [ -d $SAVE_DIR/$MODEL_NAME/version_$SEED/checkpoints ]; then 32 | # checkpoints exist, skip TRAINING 33 | : 34 | else 35 | # RUN TRAINING 36 | python src/train.py \ 37 | --model-type $PARSER_TYPE \ 38 | --bert-model-name $BERT_TYPE \ 39 | --batch-unit-type span_fast \ 40 | --batch-size 5 \ 41 | --accumulate-grad-batches 1 \ 42 | --num-workers 0 \ 43 | --lr $LR \ 44 | --num-gpus $NUM_GPUS \ 45 | --data-dir $DATA_DIR \ 46 | --save-dir $SAVE_DIR \ 47 | --model-name $MODEL_NAME \ 48 | 
--model-version $SEED \ 49 | --seed $SEED 50 | fi 51 | 52 | 53 | # RUN TEST 54 | if [ -d $SAVE_DIR/$MODEL_NAME/version_$SEED/checkpoints ]; then 55 | python src/test.py \ 56 | --num-workers 0 \ 57 | --data-dir $DATA_DIR \ 58 | --ckpt-dir $SAVE_DIR/$MODEL_NAME/version_$SEED/checkpoints \ 59 | --save-dir $SAVE_DIR/$MODEL_NAME/version_$SEED/trees 60 | else 61 | # No exists checkpoint dir 62 | : 63 | fi 64 | done 65 | EOF 66 | 67 | done 68 | done 69 | done 70 | -------------------------------------------------------------------------------- /scripts/run_shift_reduce_v1.deberta-base.1e-5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | 4 | # OPTIONS 5 | DATA_DIR=./data/rstdt 6 | SAVE_DIR=./models/rstdt 7 | PARSER_TYPE=shift_reduce_v1 8 | BERT_TYPE=deberta-base 9 | LR=1e-5 10 | NUM_GPUS=1 11 | export CUDA_VISIBLE_DEVICES=0 12 | for SEED in 0 1 2; do 13 | MODEL_NAME=$PARSER_TYPE.$BERT_TYPE.$LR 14 | VERSION=$SEED 15 | 16 | if [ -d $SAVE_DIR/$MODEL_NAME/version_$SEED/checkpoints ]; then 17 | # checkpoints exist, skip TRAINING 18 | : 19 | else 20 | # RUN TRAINING 21 | python src/train.py \ 22 | --model-type $PARSER_TYPE \ 23 | --bert-model-name $BERT_TYPE \ 24 | --batch-unit-type span_fast \ 25 | --batch-size 5 \ 26 | --accumulate-grad-batches 1 \ 27 | --num-workers 0 \ 28 | --lr $LR \ 29 | --num-gpus $NUM_GPUS \ 30 | --data-dir $DATA_DIR \ 31 | --save-dir $SAVE_DIR \ 32 | --model-name $MODEL_NAME \ 33 | --model-version $SEED \ 34 | --seed $SEED 35 | fi 36 | 37 | 38 | # RUN TEST 39 | if [ -d $SAVE_DIR/$MODEL_NAME/version_$SEED/checkpoints ]; then 40 | python src/test.py \ 41 | --num-workers 0 \ 42 | --data-dir $DATA_DIR \ 43 | --ckpt-dir $SAVE_DIR/$MODEL_NAME/version_$SEED/checkpoints \ 44 | --save-dir $SAVE_DIR/$MODEL_NAME/version_$SEED/trees 45 | else 46 | # No exists checkpoint dir 47 | : 48 | fi 49 | done 50 | -------------------------------------------------------------------------------- /src/average_ckpt.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import torch 4 | 5 | 6 | def average_checkpoints(ckpt_path_list): 7 | """modifyid average_checkpoints.py for torch-lightning 8 | https://github.com/pytorch/fairseq/blob/master/scripts/average_checkpoints.py.""" 9 | params_dict = collections.OrderedDict() 10 | params_keys = None 11 | new_ckpt = None 12 | num_ckpts = len(ckpt_path_list) 13 | 14 | for ckpt_path in ckpt_path_list: 15 | ckpt = torch.load(ckpt_path) 16 | 17 | if new_ckpt is None: 18 | new_ckpt = ckpt 19 | 20 | model_params = ckpt["state_dict"] 21 | model_params_keys = list(model_params.keys()) 22 | if params_keys is None: 23 | params_keys = model_params_keys 24 | elif params_keys != model_params_keys: 25 | raise KeyError( 26 | "For checkpoint {}, expected list of params: {}, " 27 | "but found: {}".format(ckpt_path, params_keys, model_params_keys) 28 | ) 29 | 30 | for k in params_keys: 31 | p = model_params[k] 32 | if isinstance(p, torch.HalfTensor): 33 | p = p.float() 34 | if k not in params_dict: 35 | params_dict[k] = p.clone() 36 | # NOTE: clone() is needed in case of p is a shared parameter 37 | else: 38 | params_dict[k] += p 39 | 40 | averaged_params = collections.OrderedDict() 41 | for k, v in params_dict.items(): 42 | averaged_params[k] = v 43 | if averaged_params[k].is_floating_point(): 44 | averaged_params[k].div_(num_ckpts) 45 | else: 46 | averaged_params[k] //= num_ckpts 47 | 48 | new_ckpt["state_dict"] = averaged_params 49 | return new_ckpt 
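To make the intended use of `average_checkpoints` concrete, here is a minimal usage sketch. It is an illustration only, not code from this repository; it assumes the script is run with `src/` on `PYTHONPATH` and that the checkpoint directory follows the layout described in the README (the `version_0` path is just an example).

```python
# Usage sketch (assumption: run with src/ on PYTHONPATH): average the epoch
# checkpoints of one training run and store the result as "average.ckpt".
from pathlib import Path

import torch

from average_ckpt import average_checkpoints

ckpt_dir = Path("models/rstdt/shift_reduce_v1.deberta-base.1e-5/version_0/checkpoints")
epoch_ckpts = sorted(ckpt_dir.glob("epoch=*.ckpt"))   # checkpoints saved during training
averaged_ckpt = average_checkpoints(epoch_ckpts)      # checkpoint dict with an averaged state_dict
torch.save(averaged_ckpt, ckpt_dir / "average.ckpt")
```

According to the README, `average.ckpt` is normally produced during the evaluation step (cf. the CPA column in the results table); the sketch only illustrates the function's input/output contract.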
50 | -------------------------------------------------------------------------------- /src/data/batch.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import transformers 5 | from torch.utils.data._utils.collate import default_collate 6 | 7 | from data.doc import Doc 8 | 9 | 10 | class Batch(object): 11 | def __init__(self, samples: List[dict], unit_type: str): 12 | assert unit_type in ["document", "span", "span_fast"] 13 | if unit_type != "span_fast": 14 | assert len(samples) == 1 15 | self.unit_type = unit_type 16 | 17 | if self.unit_type == "document": 18 | # If self.unit_type == 'document', a unit of a batch is a document. 19 | # Span and label contain all of the parsing elements of documents. 20 | doc: Doc = samples[0]["doc"] # Doc 21 | span = samples[0]["span"] # List[dict] 22 | label = samples[0]["label"] # List[dict] 23 | feat = samples[0]["feat"] # List[dict] 24 | elif self.unit_type == "span": 25 | # If self.unit_type == 'span', a unit of batch is a span. 26 | # Each sample has one element of parsing proceduer. 27 | doc: Doc = samples[0]["doc"] # Doc 28 | span = [samples[0]["span"]] # dict -> List[dict] 29 | label = [samples[0]["label"]] # dict -> List[dict] 30 | feat = [samples[0]["feat"]] # dict -> List[dict] 31 | elif self.unit_type == "span_fast": 32 | # If self.unit_type == 'span_fast', a unit of batch is a span. 33 | # Each sample has multiple element of parsing proceduer. 34 | doc: Doc = samples[0]["doc"] # Doc 35 | span = samples[0]["span"] # List[dict] 36 | label = samples[0]["label"] # List[dict] 37 | feat = samples[0]["feat"] # List[dict] 38 | else: 39 | raise ValueError("Invalid batch unit_type ({})".format(unit_type)) 40 | 41 | self.doc = doc 42 | self.inputs = doc.inputs 43 | self.span = span 44 | self.feat = feat 45 | self.label = default_collate(label) 46 | 47 | def __len__(self): 48 | return len(self.span) 49 | 50 | def __repr__(self): 51 | return "Batch(doc: {}, span: {}, label: {})".format(self.doc, self.span, self.label) 52 | 53 | def pin_memory(self): 54 | # if self.inputs is not None: 55 | # self.inputs = self._pin_memory(self.inputs) 56 | 57 | # self.span = self._pin_memory(self.span) 58 | self.feat = self._pin_memory(self.feat) 59 | self.label = self._pin_memory(self.label) 60 | return self 61 | 62 | def _pin_memory(self, x): 63 | if isinstance(x, torch.Tensor): 64 | return x.pin_memory() 65 | elif isinstance(x, dict): 66 | return {k: self._pin_memory(v) for k, v in x.items()} 67 | elif isinstance(x, list): 68 | return [self._pin_memory(_x) for _x in x] 69 | elif isinstance(x, tuple): 70 | return tuple(self._pin_memory(_x) for _x in x) 71 | else: 72 | raise ValueError 73 | 74 | def to_device(self, device): 75 | if self.inputs is not None: 76 | self.inputs = self._to_device(self.inputs, device) 77 | 78 | # self.span = self._to_device(self.span) 79 | self.feat = self._to_device(self.feat, device) 80 | self.label = self._to_device(self.label, device) 81 | return self 82 | 83 | def _to_device(self, x, device): 84 | if isinstance(x, torch.Tensor): 85 | return x.to(device) 86 | elif isinstance(x, dict): 87 | return {k: self._to_device(v, device) for k, v in x.items()} 88 | elif isinstance(x, list): 89 | return [self._to_device(_x, device) for _x in x] 90 | elif isinstance(x, tuple): 91 | return tuple(self._to_device(_x, device) for _x in x) 92 | elif isinstance(x, transformers.tokenization_utils_base.BatchEncoding): 93 | return x.to(device) 94 | else: 95 | raise ValueError 
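For orientation, the following sketch wraps a single "span"-unit sample into a `Batch` and moves its tensors to a device. It is an illustration only, not code from the repository: the `Doc` is a stub with no EDUs and no numericalized inputs, the span, label, and feature values are placeholders, and it assumes `src/` is on `PYTHONPATH` with the dependencies from `requirements.txt` installed. The dictionary keys follow the comments in `Batch.__init__`.

```python
# Illustrative sketch only: wrap one "span"-unit sample into a Batch and move
# its tensors to a device. The Doc is a stub (no EDUs, no numericalized inputs);
# real samples are produced by a parser's generate_training_samples().
import torch

from data.batch import Batch
from data.doc import Doc

stub_doc = Doc(edus=[], tree=None, doc_id="wsj_xxxx")  # doc.inputs stays None
sample = {
    "doc": stub_doc,
    "span": {"i": 0, "j": 4},                 # placeholder span description
    "label": {"action": torch.tensor(0)},     # placeholder gold label
    "feat": {"org": torch.zeros(1, 8)},       # placeholder organization features
}

batch = Batch([sample], unit_type="span")
print(len(batch), batch.label["action"])      # -> 1 tensor([0])
if torch.cuda.is_available():
    batch = batch.to_device(torch.device("cuda"))
```

In training, such samples are produced by a parser's `generate_training_samples` and collated into `Batch` objects by the `DataModule` shown next.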
96 | -------------------------------------------------------------------------------- /src/data/datamodule.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Union 3 | 4 | import pytorch_lightning as pl 5 | from torch.utils.data import DataLoader 6 | 7 | from data.batch import Batch 8 | from data.dataset import RSTDT, InstrDT 9 | from models.parser import ParserBase 10 | 11 | 12 | class DataModule(pl.LightningDataModule): 13 | def __init__( 14 | self, 15 | corpus: str, 16 | data_dir: Union[Path, str], 17 | train_file: Union[str, None] = None, 18 | valid_file: Union[str, None] = None, 19 | test_file: Union[str, None] = None, 20 | parser: Union[ParserBase, None] = None, 21 | batch_unit_type: str = "document", 22 | batch_size: int = 1, 23 | num_workers: int = 0, 24 | disable_span_level_validation: bool = False, 25 | ): 26 | super(DataModule, self).__init__() 27 | self.corpus = corpus 28 | self.parser = parser 29 | self.data_dir = data_dir if isinstance(data_dir, Path) else Path(data_dir) 30 | self.train_file = train_file 31 | self.valid_file = valid_file 32 | self.test_file = test_file 33 | self.batch_unit_type = batch_unit_type 34 | self.batch_size = batch_size 35 | self.num_workers = num_workers 36 | self.disable_span_level_validation = disable_span_level_validation 37 | 38 | if self.batch_unit_type != "span_fast" and self.batch_size != 1: 39 | raise ValueError( 40 | "Please use the `--accumulate-grad-batches` instead of the --batch-size." 41 | ) 42 | 43 | self.train_dataset = None 44 | self.valid_dataset = None 45 | self.test_dataset = None 46 | 47 | self.setup() 48 | 49 | @classmethod 50 | def from_config(cls, config, parser: ParserBase): 51 | if not hasattr(config, "batch_unit_type"): 52 | # for test 53 | config.batch_unit_type = "document" 54 | 55 | params = { 56 | "corpus": config.corpus, 57 | "parser": parser, 58 | "data_dir": config.data_dir, 59 | "train_file": config.train_file, 60 | "valid_file": config.valid_file, 61 | "test_file": config.test_file, 62 | "batch_unit_type": config.batch_unit_type, 63 | "batch_size": config.batch_size, 64 | "num_workers": config.num_workers, 65 | "disable_span_level_validation": getattr( 66 | config, "disable_span_level_validation", False 67 | ), 68 | } 69 | return cls(**params) 70 | 71 | def setup(self, stage=None): 72 | corpus2DATASET = { 73 | "RSTDT": RSTDT, 74 | "InstrDT": InstrDT, 75 | } 76 | assert self.corpus in corpus2DATASET 77 | DATASET = corpus2DATASET[self.corpus] 78 | 79 | if self.train_file is not None: 80 | self.train_dataset = DATASET(self.data_dir / self.train_file) 81 | if self.parser is not None: 82 | self.train_dataset.numericalize_document(self.parser.classifier.encoder) 83 | 84 | if self.valid_file is not None: 85 | self.valid_dataset = DATASET(self.data_dir / self.valid_file) 86 | if self.parser is not None: 87 | self.valid_dataset.numericalize_document(self.parser.classifier.encoder) 88 | 89 | if self.test_file is not None: 90 | self.test_dataset = DATASET(self.data_dir / self.test_file) 91 | if self.parser is not None: 92 | self.test_dataset.numericalize_document(self.parser.classifier.encoder) 93 | 94 | return 95 | 96 | def set_parser(self, parser: ParserBase): 97 | self.parser = parser 98 | 99 | def train_dataloader(self): 100 | if self.train_dataset is None: 101 | return None 102 | unit_type = self.batch_unit_type 103 | if unit_type == "span_fast": 104 | dataloader_batch_size = 1 105 | samples_batch_size = self.batch_size 106 | 
else: 107 | dataloader_batch_size = self.batch_size 108 | samples_batch_size = None 109 | 110 | samples = self.parser.generate_training_samples( 111 | self.train_dataset, unit_type, samples_batch_size 112 | ) 113 | return DataLoader( 114 | samples, 115 | batch_size=dataloader_batch_size, 116 | num_workers=self.num_workers, 117 | shuffle=True, 118 | collate_fn=collate_fn_wrapper(unit_type), 119 | pin_memory=False, 120 | ) 121 | 122 | def val_dataloader(self): 123 | if self.valid_dataset is None: 124 | return None 125 | batch_size = 1 if self.batch_unit_type == "span_fast" else self.batch_size 126 | assert batch_size == 1 127 | doc_samples = self.parser.generate_training_samples(self.valid_dataset, "document") 128 | doc_dataloader = DataLoader( 129 | doc_samples, 130 | batch_size=batch_size, 131 | num_workers=self.num_workers, 132 | shuffle=False, 133 | collate_fn=collate_fn_wrapper("document"), 134 | pin_memory=False, 135 | ) 136 | 137 | if self.disable_span_level_validation: 138 | return doc_dataloader 139 | 140 | span_samples = self.parser.generate_training_samples(self.valid_dataset, "span") 141 | span_dataloader = DataLoader( 142 | span_samples, 143 | batch_size=batch_size, 144 | num_workers=self.num_workers, 145 | shuffle=False, 146 | collate_fn=collate_fn_wrapper("span"), 147 | pin_memory=False, 148 | ) 149 | 150 | return [doc_dataloader, span_dataloader] 151 | 152 | def test_dataloader(self): 153 | if self.test_dataset is None: 154 | return None 155 | 156 | unit_type = "document" 157 | batch_size = 1 if self.batch_unit_type == "span_fast" else self.batch_size 158 | assert batch_size == 1 159 | samples = self.parser.generate_training_samples(self.test_dataset, unit_type) 160 | return DataLoader( 161 | samples, 162 | batch_size=self.batch_size, 163 | num_workers=self.num_workers, 164 | shuffle=False, 165 | collate_fn=collate_fn_wrapper(unit_type="document"), 166 | pin_memory=False, 167 | ) 168 | 169 | 170 | class collate_fn_wrapper: 171 | def __init__(self, unit_type: str): 172 | self.unit_type = unit_type 173 | 174 | def __call__(self, samples: List[dict]): 175 | # data[i]['doc']: Doc 176 | # data[i]['span']: dict 177 | # data[i]['label']: dict 178 | return Batch(samples, unit_type=self.unit_type) 179 | -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import Counter 3 | from pathlib import Path 4 | from typing import Dict, List, Union 5 | 6 | import torch 7 | from torchtext.vocab import vocab 8 | 9 | from data.doc import Doc 10 | from data.rstdt_relation import re_categorize as rstdt_re_categorize 11 | from data.tree import AttachTree, RSTTree 12 | from data.utils import is_json_file, is_jsonl_file 13 | from models.encoder import Encoder 14 | 15 | 16 | class Dataset(torch.utils.data.Dataset): 17 | nucleus_vocab = vocab( 18 | Counter(["nucleus-satellite", "satellite-nucleus", "nucleus-nucleus"]), 19 | specials=[""], 20 | ) 21 | action_vocab = vocab(Counter(["shift", "reduce"]), specials=[""]) 22 | act_nuc_vocab = vocab( 23 | Counter( 24 | [ 25 | "shift_", 26 | "reduce_nucleus-satellite", 27 | "reduce_satellite-nucleus", 28 | "reduce_nucleus-nucleus", 29 | ] 30 | ), 31 | specials=[""], 32 | ) 33 | 34 | def __init__(self, file_path: Union[Path, str]): 35 | if isinstance(file_path, str): 36 | file_path = Path(file_path) 37 | 38 | raw_dataset = self.load(file_path) 39 | dataset: List[Doc] = self.preprocess(raw_dataset) 
40 | self.dataset = dataset 41 | 42 | def __getitem__(self, idx): 43 | return self.dataset[idx] 44 | 45 | def __len__(self): 46 | return len(self.dataset) 47 | 48 | def load(self, file_path): 49 | if is_json_file(file_path): 50 | return self.load_from_json(file_path) 51 | if is_jsonl_file(file_path): 52 | return self.load_from_jsonl(file_path) 53 | 54 | raise NotImplementedError 55 | 56 | def numericalize_document(self, encoder: Encoder): 57 | for doc in self.dataset: 58 | inputs = encoder.apply_tokenizer(doc) 59 | doc.inputs = inputs 60 | 61 | return 62 | 63 | @staticmethod 64 | def load_from_json(file_path): 65 | with open(file_path) as f: 66 | dataset = json.load(f) 67 | return dataset 68 | 69 | @staticmethod 70 | def load_from_jsonl(file_path): 71 | dataset = [] 72 | with open(file_path) as f: 73 | for line in f: 74 | data = json.loads(f) 75 | dataset.append(data) 76 | 77 | return dataset 78 | 79 | 80 | class RSTDT(Dataset): 81 | relation_vocab = vocab( 82 | Counter( 83 | [ 84 | "Elaboration", 85 | "Attribution", 86 | "Joint", 87 | "Same-unit", 88 | "Contrast", 89 | "Explanation", 90 | "Background", 91 | "Cause", 92 | "Enablement", 93 | "Evaluation", 94 | "Temporal", 95 | "Condition", 96 | "Comparison", 97 | "Topic-Change", 98 | "Summary", 99 | "Manner-Means", 100 | "Textual-organization", 101 | "Topic-Comment", 102 | ] 103 | ), 104 | specials=[""], 105 | ) 106 | fully_label_vocab = vocab( # from TRAINING (TRAIN+DEV) 107 | Counter( 108 | [ 109 | # N-S 110 | "nucleus-satellite:Elaboration", 111 | "nucleus-satellite:Attribution", 112 | "nucleus-satellite:Explanation", 113 | "nucleus-satellite:Enablement", 114 | "nucleus-satellite:Background", 115 | "nucleus-satellite:Evaluation", 116 | "nucleus-satellite:Cause", 117 | "nucleus-satellite:Contrast", 118 | "nucleus-satellite:Condition", 119 | "nucleus-satellite:Comparison", 120 | "nucleus-satellite:Manner-Means", 121 | "nucleus-satellite:Summary", 122 | "nucleus-satellite:Temporal", 123 | "nucleus-satellite:Topic-Comment", 124 | "nucleus-satellite:Topic-Change", 125 | # S-N 126 | "satellite-nucleus:Attribution", 127 | "satellite-nucleus:Contrast", 128 | "satellite-nucleus:Background", 129 | "satellite-nucleus:Condition", 130 | "satellite-nucleus:Cause", 131 | "satellite-nucleus:Evaluation", 132 | "satellite-nucleus:Temporal", 133 | "satellite-nucleus:Explanation", 134 | "satellite-nucleus:Enablement", 135 | "satellite-nucleus:Comparison", 136 | "satellite-nucleus:Elaboration", 137 | "satellite-nucleus:Manner-Means", 138 | "satellite-nucleus:Summary", 139 | "satellite-nucleus:Topic-Comment", 140 | # N-N 141 | "nucleus-nucleus:Joint", 142 | "nucleus-nucleus:Same-unit", 143 | "nucleus-nucleus:Contrast", 144 | "nucleus-nucleus:Temporal", 145 | "nucleus-nucleus:Topic-Change", 146 | "nucleus-nucleus:Textual-organization", 147 | "nucleus-nucleus:Comparison", 148 | "nucleus-nucleus:Topic-Comment", 149 | "nucleus-nucleus:Cause", 150 | "nucleus-nucleus:Condition", 151 | "nucleus-nucleus:Explanation", 152 | "nucleus-nucleus:Evaluation", 153 | ] 154 | ), 155 | specials=[""], 156 | ) 157 | 158 | def preprocess(self, raw_dataset: List[Dict]): 159 | dataset = [] 160 | for data in raw_dataset: 161 | rst_tree = RSTTree.fromstring(data["rst_tree"]) 162 | rst_tree = rstdt_re_categorize(rst_tree) 163 | assert RSTTree.check_relation(rst_tree, self.relation_vocab) 164 | bi_rst_tree = RSTTree.binarize(rst_tree) 165 | attach_tree = RSTTree.convert_to_attach(bi_rst_tree) 166 | data["attach_tree"] = attach_tree 167 | # (wsj_1189 has annotateion error) 168 | if 
data["doc_id"] != "wsj_1189": # check conversion 169 | assert bi_rst_tree == AttachTree.convert_to_rst(attach_tree) 170 | 171 | tokenized_edu_strings = [] 172 | edu_starts_sentence = [] 173 | 174 | tokens = data["tokens"] 175 | edu_start_indices = data["edu_start_indices"] 176 | sent_id, token_id, _ = edu_start_indices[0] 177 | for next_sent_id, next_token_id, _ in edu_start_indices[1:] + [(-1, -1, -1)]: 178 | end_token_id = next_token_id if token_id < next_token_id else None 179 | tokenized_edu_strings.append(tokens[sent_id][token_id:end_token_id]) 180 | edu_starts_sentence.append(token_id == 0) 181 | sent_id = next_sent_id 182 | token_id = next_token_id 183 | 184 | data["tokenized_edu_strings"] = tokenized_edu_strings 185 | data["edu_starts_sentence"] = edu_starts_sentence 186 | 187 | doc = Doc.from_data(data) 188 | dataset.append(doc) 189 | 190 | return dataset 191 | 192 | 193 | class InstrDT(Dataset): 194 | relation_vocab = vocab( 195 | Counter( 196 | [ 197 | "preparation:act", 198 | "joint", 199 | "general:specific", 200 | "criterion:act", 201 | "goal:act", 202 | "act:goal", 203 | "textualorganization", 204 | "topic-change?", 205 | "step1:step2", 206 | "disjunction", 207 | "contrast1:contrast2", 208 | "co-temp1:co-temp2", 209 | "act:reason", 210 | "act:criterion", 211 | "cause:effect", 212 | "comparision", 213 | "reason:act", 214 | "act:preparation", 215 | "situation:circumstance", 216 | "same-unit", 217 | "object:attribute", 218 | "effect:cause", 219 | "prescribe-act:wrong-act", 220 | "indeterminate", 221 | "specific:general", 222 | "before:after", 223 | "set:member", 224 | "situation:obstacle", 225 | "wrong-act:prescribe-act", 226 | "act:constraint", 227 | "circumstance:situation", 228 | "act:side-effect", 229 | "obstacle:situation", 230 | "after:before", 231 | "side-effect:act", 232 | "wrong-act:criterion", 233 | "attribute:object", 234 | "criterion:wrong-act", 235 | "constraint:act", 236 | ] 237 | ), 238 | specials=[""], 239 | ) 240 | 241 | def preprocess(self, raw_dataset: List[Dict]): 242 | dataset = [] 243 | for data in raw_dataset: 244 | rst_tree = RSTTree.fromstring(data["rst_tree"]) 245 | if not RSTTree.is_valid_tree(rst_tree): 246 | continue 247 | assert RSTTree.check_relation(rst_tree, self.relation_vocab) 248 | bi_rst_tree = RSTTree.binarize(rst_tree) 249 | attach_tree = RSTTree.convert_to_attach(bi_rst_tree) 250 | data["attach_tree"] = attach_tree 251 | assert bi_rst_tree == AttachTree.convert_to_rst(attach_tree) 252 | 253 | doc = Doc.from_data(data) 254 | dataset.append(doc) 255 | 256 | return dataset 257 | -------------------------------------------------------------------------------- /src/data/doc.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | from spacy.lang.en import English 4 | 5 | from data.edu import EDU 6 | from data.tree import AttachTree, RSTTree 7 | 8 | nlp = English() 9 | 10 | 11 | class Doc(object): 12 | def __init__(self, edus: List[EDU], tree: Union[RSTTree, AttachTree], doc_id: str): 13 | self.edus = edus 14 | self.tree = tree 15 | self.doc_id = doc_id 16 | self.inputs = None # numericalized edus 17 | 18 | def __repr__(self): 19 | return 'Doc(doc_id: "{}", tree: {}, edus: {})'.format( 20 | self.doc_id, type(self.tree).__name__, [edu.edu_string for edu in self.edus] 21 | ) 22 | 23 | def get_edu_strings(self): 24 | return [edu.edu_string for edu in self.edus] 25 | 26 | @classmethod 27 | def from_data(cls, data: dict): 28 | assert "doc_id" in data 29 | doc_id = data["doc_id"] 
30 | 
31 |         assert "attach_tree" in data
32 |         tree = data["attach_tree"]
33 | 
34 |         assert "edu_strings" in data
35 |         edu_strings = data["edu_strings"] # non-tokenized
36 | 
37 |         assert "edu_starts_sentence" in data
38 |         assert "edu_starts_paragraph" in data
39 |         edu_starts_sentence = data["edu_starts_sentence"]
40 |         edu_starts_paragraph = data["edu_starts_paragraph"]
41 |         edu_ends_sentence = edu_starts_sentence[1:] + [True]
42 |         edu_ends_paragraph = edu_starts_paragraph[1:] + [True]
43 | 
44 |         if "tokenized_edu_strings" in data:
45 |             tokenized_edu_strings = data["tokenized_edu_strings"]
46 |         else:
47 |             tokenized_edu_strings = [[token.text for token in nlp(edu)] for edu in edu_strings]
48 | 
49 |         edus = []
50 |         sent_idx, para_idx = 0, 0
51 |         for edu_idx, edu_string in enumerate(edu_strings):
52 |             edu_tokens = tokenized_edu_strings[edu_idx]
53 | 
54 |             is_start_sent = edu_starts_sentence[edu_idx]
55 |             is_end_sent = edu_ends_sentence[edu_idx]
56 | 
57 |             is_start_para = edu_starts_paragraph[edu_idx]
58 |             is_end_para = edu_ends_paragraph[edu_idx]
59 | 
60 |             is_start_doc = True if edu_idx == 0 else False
61 |             is_end_doc = True if edu_idx == len(edu_strings) - 1 else False
62 | 
63 |             if is_start_sent:
64 |                 sent_idx += 1
65 |             if is_start_para:
66 |                 para_idx += 1
67 | 
68 |             edu = EDU(
69 |                 edu_string,
70 |                 edu_tokens,
71 |                 sent_idx,
72 |                 para_idx,
73 |                 is_start_sent,
74 |                 is_end_sent,
75 |                 is_start_para,
76 |                 is_end_para,
77 |                 is_start_doc,
78 |                 is_end_doc,
79 |             )
80 |             edus.append(edu)
81 | 
82 |         return cls(edus, tree, doc_id)
83 | 
--------------------------------------------------------------------------------
/src/data/edu.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | 
3 | 
4 | class EDU(object):
5 |     def __init__(
6 |         self,
7 |         edu_string: str,
8 |         tokens: List[str],
9 |         sent_idx: int,
10 |         para_idx: int,
11 |         is_start_sent: bool,
12 |         is_end_sent: bool,
13 |         is_start_para: bool,
14 |         is_end_para: bool,
15 |         is_start_doc: bool,
16 |         is_end_doc: bool,
17 |     ):
18 |         self.edu_string = edu_string
19 |         self.tokens = tokens
20 |         self.sent_idx = sent_idx
21 |         self.para_idx = para_idx
22 |         self.is_start_sent = is_start_sent
23 |         self.is_end_sent = is_end_sent
24 |         self.is_start_para = is_start_para
25 |         self.is_end_para = is_end_para
26 |         self.is_start_doc = is_start_doc
27 |         self.is_end_doc = is_end_doc
28 | 
29 |     def __repr__(self):
30 |         return 'EDU(edu_string: "{}", tokens: {}, flags: [{}, {}, {}, {}, {}, {}])'.format(
31 |             self.edu_string,
32 |             self.tokens,
33 |             self.is_start_sent,
34 |             self.is_end_sent,
35 |             self.is_start_para,
36 |             self.is_end_para,
37 |             self.is_start_doc,
38 |             self.is_end_doc,
39 |         )
40 | 
--------------------------------------------------------------------------------
/src/data/rstdt_relation.py:
--------------------------------------------------------------------------------
1 | from data.tree import RSTTree
2 | 
3 | 
4 | def re_categorize(tree: RSTTree):
5 |     def helper(node):
6 |         if not isinstance(node, RSTTree):
7 |             return node
8 | 
9 |         label = node.label()
10 |         if label not in ["ROOT", "text"]:
11 |             nuc, rel = node.label().split(":", maxsplit=1)
12 |             while rel[-2:] in ["-s", "-e", "-n"]:
13 |                 rel = rel[:-2]
14 | 
15 |             rel = RELATION_TABLE[rel]
16 |             label = ":".join([nuc, rel])
17 | 
18 |         return RSTTree(label, [helper(child) for child in node])
19 | 
20 |     assert isinstance(tree, RSTTree)
21 |     return helper(tree)
22 | 
23 | 
24 | RELATION_TABLE = {
25 |     "ROOT": "ROOT",
26 |     "span": "span",
27 |     "attribution": "Attribution",
28 | 
"attribution-negative": "Attribution", 29 | "background": "Background", 30 | "circumstance": "Background", 31 | "cause": "Cause", 32 | "result": "Cause", 33 | "cause-result": "Cause", 34 | "consequence": "Cause", 35 | "comparison": "Comparison", 36 | "preference": "Comparison", 37 | "analogy": "Comparison", 38 | "proportion": "Comparison", 39 | "condition": "Condition", 40 | "hypothetical": "Condition", 41 | "contingency": "Condition", 42 | "otherwise": "Condition", 43 | "contrast": "Contrast", 44 | "concession": "Contrast", 45 | "antithesis": "Contrast", 46 | "elaboration-additional": "Elaboration", 47 | "elaboration-general-specific": "Elaboration", 48 | "elaboration-part-whole": "Elaboration", 49 | "elaboration-process-step": "Elaboration", 50 | "elaboration-object-attribute": "Elaboration", 51 | "elaboration-set-member": "Elaboration", 52 | "example": "Elaboration", 53 | "definition": "Elaboration", 54 | "enablement": "Enablement", 55 | "purpose": "Enablement", 56 | "evaluation": "Evaluation", 57 | "interpretation": "Evaluation", 58 | "conclusion": "Evaluation", 59 | "comment": "Evaluation", 60 | "evidence": "Explanation", 61 | "explanation-argumentative": "Explanation", 62 | "reason": "Explanation", 63 | "list": "Joint", 64 | "disjunction": "Joint", 65 | "manner": "Manner-Means", 66 | "means": "Manner-Means", 67 | "problem-solution": "Topic-Comment", 68 | "question-answer": "Topic-Comment", 69 | "statement-response": "Topic-Comment", 70 | "topic-comment": "Topic-Comment", 71 | "comment-topic": "Topic-Comment", 72 | "rhetorical-question": "Topic-Comment", 73 | "summary": "Summary", 74 | "restatement": "Summary", 75 | "temporal-before": "Temporal", 76 | "temporal-after": "Temporal", 77 | "temporal-same-time": "Temporal", 78 | "sequence": "Temporal", 79 | "inverted-sequence": "Temporal", 80 | "topic-shift": "Topic-Change", 81 | "topic-drift": "Topic-Change", 82 | "textualorganization": "Textual-organization", 83 | "same-unit": "Same-unit", 84 | } 85 | -------------------------------------------------------------------------------- /src/data/tree.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from nltk import Tree 4 | from torchtext.vocab import Vocab 5 | 6 | 7 | class RSTTree(Tree): 8 | def __init__(self, label: str, children: List): 9 | super(RSTTree, self).__init__(label, children) 10 | self.nuc = None 11 | self.rel = None 12 | if label not in ["ROOT", "text"]: 13 | nuc, rel = label.split(":", maxsplit=1) 14 | 15 | @classmethod 16 | def binarize(cls, tree: Tree): 17 | def helper(node): 18 | if len(node) == 1: 19 | # End of recursion 20 | return node 21 | elif len(node) == 2: 22 | # Binary structure 23 | left_node = helper(node[0]) 24 | right_node = helper(node[1]) 25 | else: 26 | # Non-Binary structure 27 | labels = [node[i].label() for i in range(len(node))] 28 | is_polynuclear = all(map(lambda x: x == labels[0], labels)) 29 | if is_polynuclear: 30 | # Polynuclear relation label such as: 31 | # same-unit, list, etc... 
32 | # -> convert to right heavy structure 33 | left_node = helper(node[0]) 34 | right_node = helper( 35 | cls(node[0].label(), [node[i] for i in range(1, len(node))]) 36 | ) 37 | else: 38 | # Non Binary structure without Polynuclear label 39 | # S/N/S -> left heavy 40 | left_node = helper(cls("nucleus:span", [node[0], node[1]])) 41 | right_node = helper(node[2]) 42 | 43 | return cls(node.label(), [left_node, right_node]) 44 | 45 | assert isinstance(tree, RSTTree) 46 | return helper(tree) 47 | 48 | @classmethod 49 | def is_binary(cls, tree: Tree): 50 | def helper(node): 51 | if not isinstance(node, RSTTree): 52 | return True 53 | elif len(node) > 2: 54 | return False 55 | else: 56 | return all([helper(child) for child in node]) 57 | 58 | assert isinstance(tree, RSTTree) 59 | return helper(tree) 60 | 61 | @classmethod 62 | def convert_to_attach(cls, tree: Tree): 63 | def helper(node): 64 | if len(node) == 1: 65 | edu_idx = node[0][0] 66 | return AttachTree("text", [edu_idx]) 67 | 68 | l_nuc, l_rel = node[0].label().split(":", maxsplit=1) 69 | r_nuc, r_rel = node[1].label().split(":", maxsplit=1) 70 | nuc = "-".join([l_nuc, r_nuc]) 71 | rel = l_rel if l_rel != "span" else r_rel 72 | label = ":".join([nuc, rel]) 73 | return AttachTree(label, [helper(child) for child in node]) 74 | 75 | assert RSTTree.is_binary(tree) 76 | assert isinstance(tree, RSTTree) 77 | return helper(tree) 78 | 79 | @classmethod 80 | def is_valid_tree(cls, tree: Tree): 81 | assert isinstance(tree, RSTTree) 82 | if len(tree) == 1: 83 | # (ROOT (text 0)) 84 | return False 85 | 86 | return True 87 | 88 | @classmethod 89 | def check_relation(cls, tree: Tree, relation_vocab: Vocab): 90 | for tp in tree.treepositions(): 91 | node = tree[tp] 92 | if not isinstance(node, RSTTree): 93 | continue 94 | 95 | label = tree.label() 96 | if label in ["ROOT", "text"]: 97 | continue 98 | 99 | nuc, rel = label.split(":", maxsplit=1) 100 | if rel not in relation_vocab: 101 | return False 102 | 103 | return True 104 | 105 | def get_brackets(self, eval_types: List[str]): 106 | brackets = {eval_type: [] for eval_type in eval_types} 107 | 108 | for tp in self.treepositions(): 109 | node = self[tp] 110 | if not isinstance(node, RSTTree): 111 | continue # EDU idx 112 | 113 | label = node.label() 114 | if label == "ROOT" and tp == (): 115 | continue # ROOT node 116 | 117 | if label == "text" and len(node) == 1: 118 | continue # leave node 119 | 120 | edu_indices = [int(idx) for idx in node.leaves()] 121 | span = (edu_indices[0], edu_indices[-1] + 1) 122 | ns, relation = label.split(":", maxsplit=1) 123 | 124 | if "full" in eval_types: 125 | brackets["full"].append((span, ns, relation)) 126 | if "rel" in eval_types: 127 | brackets["rel"].append((span, relation)) 128 | if "nuc" in eval_types: 129 | brackets["nuc"].append((span, ns)) 130 | if "span" in eval_types: 131 | brackets["span"].append((span)) 132 | 133 | return brackets 134 | 135 | 136 | class AttachTree(Tree): 137 | def __init__(self, label: str, children: List): 138 | super(AttachTree, self).__init__(label, children) 139 | self.nuc = None 140 | self.rel = None 141 | if label == "text": 142 | # EDU node 143 | pass 144 | else: 145 | nuc, rel = label.split(":", maxsplit=1) 146 | 147 | @classmethod 148 | def convert_to_rst(cls, tree: Tree): 149 | def helper(node, label="ROOT"): 150 | if len(node) == 1: 151 | edu_idx = node[0] 152 | return RSTTree(label, [RSTTree("text", [edu_idx])]) 153 | 154 | nuc, rel = node.label().split(":", maxsplit=1) 155 | if len(nuc.split("-")) == 1: 156 | raise 
ValueError("Invalid nucleus label: {}".format(nuc)) 157 | l_nuc, r_nuc = nuc.split("-") 158 | if nuc == "nucleus-satellite": 159 | l_rel, r_rel = "span", rel 160 | elif nuc == "satellite-nucleus": 161 | l_rel, r_rel = rel, "span" 162 | elif nuc == "nucleus-nucleus": 163 | l_rel = r_rel = rel 164 | else: 165 | raise ValueError("Unkwon Nucleus label: {}".format(nuc)) 166 | 167 | l_label = ":".join([l_nuc, l_rel]) 168 | r_label = ":".join([r_nuc, r_rel]) 169 | return RSTTree(label, [helper(node[0], l_label), helper(node[1], r_label)]) 170 | 171 | assert isinstance(tree, AttachTree) 172 | return helper(tree) 173 | 174 | def get_brackets(self, eval_types: List[str]): 175 | brackets = {eval_type: [] for eval_type in eval_types} 176 | 177 | for tp in self.treepositions(): 178 | node = self[tp] 179 | if not isinstance(node, AttachTree): 180 | continue # EDU idx 181 | 182 | label = node.label() 183 | if label == "text" and len(node) == 1: 184 | continue # leave node 185 | 186 | edu_indices = [int(idx) for idx in node.leaves()] 187 | span = (edu_indices[0], edu_indices[-1] + 1) 188 | ns, relation = label.split(":", maxsplit=1) 189 | 190 | if "full" in eval_types: 191 | brackets["full"].append((span, ns, relation)) 192 | if "rel" in eval_types: 193 | brackets["rel"].append((span, relation)) 194 | if "nuc" in eval_types: 195 | brackets["nuc"].append((span, ns)) 196 | if "span" in eval_types: 197 | brackets["span"].append((span)) 198 | 199 | return brackets 200 | -------------------------------------------------------------------------------- /src/data/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | def is_json_file(file_path: Path): 5 | if file_path.suffix == ".json": 6 | return True 7 | return False 8 | 9 | 10 | def is_jsonl_file(file_path: Path): 11 | if file_path.suffix == ".jsonl": 12 | return True 13 | return False 14 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from metrics.parseval import Parseval 2 | from metrics.rst_parseval import RSTParseval 3 | from metrics.original_parseval import OriginalParseval 4 | 5 | __all__ = ["Parseval", "OriginalParseval", "RSTParseval"] 6 | -------------------------------------------------------------------------------- /src/metrics/original_parseval.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from data.tree import AttachTree, RSTTree 4 | from metrics import Parseval 5 | 6 | 7 | class OriginalParseval(Parseval): 8 | def __init__(self): 9 | super(OriginalParseval, self).__init__() 10 | 11 | def convert_tree(self, tree: Union[RSTTree, AttachTree]): 12 | # RSTTree -> AttachTree 13 | if isinstance(tree, RSTTree): 14 | tree = RSTTree.convert_to_attach(tree) 15 | 16 | return tree 17 | -------------------------------------------------------------------------------- /src/metrics/parseval.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | 6 | from data.tree import AttachTree, RSTTree 7 | 8 | 9 | class Parseval(Metric): 10 | def __init__(self): 11 | super(Parseval, self).__init__(compute_on_step=False) 12 | self.eval_types = ["span", "nuc", "rel", "full"] 13 | for eval_type in self.eval_types: 14 | self.add_state( 15 | 
"match_{}".format(eval_type), 16 | default=torch.tensor(0), 17 | dist_reduce_fx="sum", 18 | ) 19 | 20 | self.add_state("pred", default=torch.tensor(0.0), dist_reduce_fx="sum") 21 | self.add_state("gold", default=torch.tensor(0.0), dist_reduce_fx="sum") 22 | 23 | def update( 24 | self, 25 | pred_trees: List[Union[RSTTree, AttachTree]], 26 | gold_trees: List[Union[RSTTree, AttachTree]], 27 | ): 28 | assert len(pred_trees) == len(gold_trees) 29 | 30 | for pred_tree, gold_tree in zip(pred_trees, gold_trees): 31 | # convert tree 32 | pred_tree = self.convert_tree(pred_tree) 33 | gold_tree = self.convert_tree(gold_tree) 34 | # get brackets 35 | pred_brackets = pred_tree.get_brackets(self.eval_types) 36 | gold_brackets = gold_tree.get_brackets(self.eval_types) 37 | # count brackets 38 | pred_cnt = len(pred_brackets["span"]) 39 | gold_cnt = len(gold_brackets["span"]) 40 | assert pred_cnt == gold_cnt 41 | self.pred += pred_cnt 42 | self.gold += gold_cnt 43 | 44 | self.match_span += len( 45 | [bracket for bracket in pred_brackets["span"] if bracket in gold_brackets["span"]] 46 | ) 47 | self.match_nuc += len( 48 | [bracket for bracket in pred_brackets["nuc"] if bracket in gold_brackets["nuc"]] 49 | ) 50 | self.match_rel += len( 51 | [bracket for bracket in pred_brackets["rel"] if bracket in gold_brackets["rel"]] 52 | ) 53 | self.match_full += len( 54 | [bracket for bracket in pred_brackets["full"] if bracket in gold_brackets["full"]] 55 | ) 56 | 57 | def compute(self): 58 | metric_name = self.__class__.__name__ 59 | return { 60 | "{}-S".format(metric_name): self.match_span / self.pred, 61 | "{}-N".format(metric_name): self.match_nuc / self.pred, 62 | "{}-R".format(metric_name): self.match_rel / self.pred, 63 | "{}-F".format(metric_name): self.match_full / self.pred, 64 | } 65 | 66 | def convert_tree(self, tree: Union[RSTTree, AttachTree]): 67 | raise NotImplementedError 68 | -------------------------------------------------------------------------------- /src/metrics/rst_parseval.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from data.tree import AttachTree, RSTTree 4 | from metrics import Parseval 5 | 6 | 7 | class RSTParseval(Parseval): 8 | def __init__(self): 9 | super(RSTParseval, self).__init__() 10 | 11 | def convert_tree(self, tree: Union[RSTTree, AttachTree]): 12 | # AttachTree -> RSTTree 13 | if isinstance(tree, AttachTree): 14 | tree = AttachTree.convert_to_rst(tree) 15 | 16 | return tree 17 | -------------------------------------------------------------------------------- /src/models/classifier/__init__.py: -------------------------------------------------------------------------------- 1 | from models.classifier.classifier_base import ClassifierBase 2 | from models.classifier.shift_reduce_classifier_base import ShiftReduceClassifierBase 3 | from models.classifier.shift_reduce_classifier_v1 import ShiftReduceClassifierV1 4 | from models.classifier.shift_reduce_classifier_v2 import ShiftReduceClassifierV2 5 | from models.classifier.shift_reduce_classifier_v3 import ShiftReduceClassifierV3 6 | from models.classifier.top_down_classifier_base import TopDownClassifierBase 7 | from models.classifier.top_down_classifier_v1 import TopDownClassifierV1 8 | from models.classifier.top_down_classifier_v2 import TopDownClassifierV2 9 | 10 | 11 | class Classifiers: 12 | classifier_dict = { 13 | "top_down_v1": TopDownClassifierV1, 14 | "top_down_v2": TopDownClassifierV2, 15 | "shift_reduce_v1": ShiftReduceClassifierV1, 16 | 
"shift_reduce_v2": ShiftReduceClassifierV2, 17 | "shift_reduce_v3": ShiftReduceClassifierV3, 18 | } 19 | 20 | @classmethod 21 | def from_config(cls, config): 22 | classifier_type = config.model_type 23 | classifier = cls.classifier_dict[classifier_type].from_config(config) 24 | return classifier 25 | 26 | 27 | __all__ = [ 28 | "ClassifierBase", 29 | "TopDownClassifierBase", 30 | "ShiftReduceClassifierBase", 31 | "TopDownClassifierV1", 32 | "TopDownClassifierV2", 33 | "ShiftReduceClassifierV1", 34 | "ShiftReduceClassifierV2", 35 | "ShiftReduceClassifierV3", 36 | ] 37 | -------------------------------------------------------------------------------- /src/models/classifier/classifier_base.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import pytorch_lightning as pl 4 | from torch.optim import AdamW 5 | from transformers import get_linear_schedule_with_warmup 6 | 7 | from data.batch import Batch 8 | from data.dataset import RSTDT, InstrDT 9 | from metrics import OriginalParseval, RSTParseval 10 | from models.encoder import BertEncoder 11 | 12 | 13 | class ClassifierBase(pl.LightningModule): 14 | def __init__( 15 | self, 16 | model_type: str, 17 | bert_model_name: str, 18 | bert_max_length: int, 19 | bert_stride: int, 20 | corpus: str, 21 | accumulate_grad_batches: int, 22 | batch_unit_type: str, 23 | lr_for_encoder: float, 24 | lr: float, 25 | disable_lr_schedule: bool = False, 26 | disable_org_sent: bool = False, 27 | disable_org_para: bool = False, 28 | ): 29 | super(ClassifierBase, self).__init__() 30 | self.save_hyperparameters() 31 | 32 | self.lr = lr 33 | self.lr_for_encoder = lr_for_encoder 34 | self.disable_lr_schedule = disable_lr_schedule 35 | 36 | corpus2DATASET = { 37 | "RSTDT": RSTDT, 38 | "InstrDT": InstrDT, 39 | } 40 | assert corpus in corpus2DATASET 41 | self.DATASET = corpus2DATASET[corpus] 42 | 43 | self.met_rst_parseval = RSTParseval() 44 | self.met_ori_parseval = OriginalParseval() 45 | 46 | self.met_rst_parseval_oracle = RSTParseval() 47 | self.met_ori_parseval_oracle = OriginalParseval() 48 | 49 | self.parser = None 50 | self.encoder = BertEncoder(bert_model_name, bert_max_length, bert_stride) 51 | 52 | self.disable_org_sent = disable_org_sent 53 | self.disable_org_para = disable_org_para 54 | self.disable_org_feat = self.disable_org_sent and self.disable_org_para 55 | 56 | @classmethod 57 | def from_config(cls, config): 58 | params = cls.params_from_config(config) 59 | return cls(**params) 60 | 61 | @classmethod 62 | def params_from_config(cls, config): 63 | return { 64 | "model_type": config.model_type, 65 | "bert_model_name": config.bert_model_name, 66 | "bert_max_length": config.bert_max_length, 67 | "bert_stride": config.bert_stride, 68 | "corpus": config.corpus, 69 | "accumulate_grad_batches": config.accumulate_grad_batches, 70 | "batch_unit_type": config.batch_unit_type, 71 | "lr_for_encoder": config.lr_for_encoder, 72 | "lr": config.lr, 73 | "disable_lr_schedule": config.disable_lr_schedule, 74 | "disable_org_sent": config.disable_org_sent, 75 | "disable_org_para": config.disable_org_para, 76 | } 77 | 78 | def set_parser(self, parser): 79 | self.parser = parser 80 | 81 | def set_training_steps_par_epoch(self, training_steps): 82 | self.training_steps_par_epoch = training_steps 83 | 84 | def init_org_embeddings(self): 85 | raise NotImplementedError 86 | 87 | def get_org_embedding_dim(self): 88 | if self.disable_org_feat: 89 | return 0 90 | 91 | n_dim = self.org_embed.embedding_dim 92 | 
n_feat = self.org_embed.num_embeddings // 2 93 | feat_embed_dim = n_feat * n_dim 94 | return feat_embed_dim 95 | 96 | def forward(self): 97 | raise NotImplementedError 98 | 99 | def training_loss(self, batch: Batch): 100 | raise NotImplementedError 101 | 102 | def on_train_start(self): 103 | self.logger.log_hyperparams( 104 | self.hparams, 105 | { 106 | "hp_metric/RSTParseval-S": 0, 107 | "hp_metric/RSTParseval-N": 0, 108 | "hp_metric/RSTParseval-R": 0, 109 | "hp_metric/RSTParseval-F": 0, 110 | "hp_metric/OriginalParseval-S": 0, 111 | "hp_metric/OriginalParseval-N": 0, 112 | "hp_metric/OriginalParseval-R": 0, 113 | "hp_metric/OriginalParseval-F": 0, 114 | }, 115 | ) 116 | 117 | def training_step(self, batch: Batch, batch_idx: Union[int, None] = None): 118 | batch.to_device(self.device) 119 | loss_dict = self.training_loss(batch) 120 | 121 | for name, value in loss_dict.items(): 122 | self.log( 123 | "train/{}".format(name), 124 | value, 125 | on_epoch=True, 126 | on_step=True, 127 | batch_size=len(batch), 128 | ) 129 | 130 | return loss_dict["loss"] 131 | 132 | def validation_step(self, batch, batch_idx: int, dataloader_idx: Union[int, None] = None): 133 | batch.to_device(self.device) 134 | if batch.unit_type == "span": 135 | loss_dict = self.training_loss(batch) 136 | 137 | for name, value in loss_dict.items(): 138 | self.log( 139 | "valid/{}".format(name), 140 | value, 141 | on_epoch=True, 142 | on_step=False, 143 | prog_bar=True, 144 | batch_size=len(batch), 145 | ) 146 | 147 | return loss_dict["loss"] 148 | 149 | elif batch.unit_type == "document": 150 | doc = batch.doc 151 | pred_tree = self.parser.parse(doc) 152 | gold_tree = doc.tree 153 | 154 | self.met_rst_parseval([pred_tree], [gold_tree]) 155 | self.met_ori_parseval([pred_tree], [gold_tree]) 156 | 157 | pred_tree = self.parser.parse_with_naked_tree(doc, doc.tree) 158 | gold_tree = doc.tree 159 | 160 | self.met_rst_parseval_oracle([pred_tree], [gold_tree]) 161 | self.met_ori_parseval_oracle([pred_tree], [gold_tree]) 162 | 163 | return 164 | else: 165 | raise ValueError 166 | 167 | def test_step(self, batch, batch_idx=None): 168 | batch.to_device(self.device) 169 | assert batch.unit_type == "document" 170 | 171 | doc = batch.doc 172 | pred_tree = self.parser.parse(doc) 173 | gold_tree = doc.tree 174 | 175 | self.met_rst_parseval([pred_tree], [gold_tree]) 176 | self.met_ori_parseval([pred_tree], [gold_tree]) 177 | 178 | return 179 | 180 | def validation_epoch_end(self, outputs: List): 181 | scores = self.met_rst_parseval.compute() 182 | for name, value in scores.items(): 183 | self.log("valid/{}".format(name), value, prog_bar=True) 184 | self.log("hp_metric/{}".format(name), value, prog_bar=True) 185 | 186 | self.met_rst_parseval.reset() 187 | 188 | scores = self.met_ori_parseval.compute() 189 | for name, value in scores.items(): 190 | self.log("valid/{}".format(name), value, prog_bar=True) 191 | self.log("hp_metric/{}".format(name), value, prog_bar=True) 192 | 193 | self.met_ori_parseval.reset() 194 | 195 | scores = self.met_rst_parseval_oracle.compute() 196 | for name, value in scores.items(): 197 | self.log("valid/{}_oracle".format(name), value, prog_bar=False) 198 | 199 | self.met_rst_parseval_oracle.reset() 200 | 201 | scores = self.met_ori_parseval_oracle.compute() 202 | for name, value in scores.items(): 203 | self.log("valid/{}_oracle".format(name), value, prog_bar=False) 204 | 205 | self.met_ori_parseval_oracle.reset() 206 | 207 | return 208 | 209 | def test_epoch_end(self, outputs): 210 | scores = 
self.met_rst_parseval.compute() 211 | for name, value in scores.items(): 212 | self.log("test/{}".format(name), value, prog_bar=True) 213 | 214 | self.met_rst_parseval.reset() 215 | 216 | scores = self.met_ori_parseval.compute() 217 | for name, value in scores.items(): 218 | self.log("test/{}".format(name), value, prog_bar=True) 219 | 220 | self.met_ori_parseval.reset() 221 | 222 | return 223 | 224 | def configure_optimizers(self): 225 | no_decay = ["bias", "LayerNorm.weight"] 226 | params = [ 227 | # BERT parameters 228 | { 229 | "params": [ 230 | p 231 | for n, p in self.named_parameters() 232 | if not any(nd in n for nd in no_decay) and "encoder" in n 233 | ], 234 | "weight_decay": 0.01, 235 | "lr": self.lr_for_encoder, 236 | }, 237 | { 238 | "params": [ 239 | p 240 | for n, p in self.named_parameters() 241 | if any(nd in n for nd in no_decay) and "encoder" in n 242 | ], 243 | "weight_decay": 0.0, 244 | "lr": self.lr_for_encoder, 245 | }, 246 | # Other parameters 247 | { 248 | "params": [ 249 | p 250 | for n, p in self.named_parameters() 251 | if not any(nd in n for nd in no_decay) and "encoder" not in n 252 | ], 253 | "weight_decay": 0.01, 254 | "lr": self.lr, 255 | }, 256 | { 257 | "params": [ 258 | p 259 | for n, p in self.named_parameters() 260 | if any(nd in n for nd in no_decay) and "encoder" not in n 261 | ], 262 | "weight_decay": 0.0, 263 | "lr": self.lr, 264 | }, 265 | ] 266 | optimizer = AdamW(params) 267 | optimizers = [optimizer] 268 | 269 | if self.disable_lr_schedule: 270 | return optimizers 271 | 272 | num_epochs = self.trainer.max_epochs 273 | accum_size = self.trainer.accumulate_grad_batches 274 | num_training_steps = self.training_steps_par_epoch / accum_size * num_epochs 275 | num_warmup_steps = int(num_training_steps * 0.1) 276 | scheduler = get_linear_schedule_with_warmup( 277 | optimizer, 278 | num_warmup_steps=num_warmup_steps, 279 | num_training_steps=num_training_steps, 280 | ) 281 | 282 | schedulers = [{"scheduler": scheduler, "interval": "step", "frequency": 1}] 283 | return optimizers, schedulers 284 | -------------------------------------------------------------------------------- /src/models/classifier/linear.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class FeedForward(nn.Sequential): 5 | def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout_p: float = 0.2): 6 | super(FeedForward, self).__init__( 7 | nn.Linear(input_dim, hidden_dim), 8 | nn.GELU(), 9 | nn.Dropout(dropout_p), 10 | nn.Linear(hidden_dim, output_dim), 11 | ) 12 | 13 | 14 | class DeepBiAffine(nn.Module): 15 | def __init__( 16 | self, 17 | input_dim: int, 18 | hidden_dim: int, 19 | output_dim: int, 20 | dropout_p: float = 0.2, 21 | feat_embed_dim: int = 0, 22 | ): 23 | super(DeepBiAffine, self).__init__() 24 | self.W_left = FeedForward(input_dim, hidden_dim, hidden_dim, dropout_p) 25 | self.W_right = FeedForward(input_dim, hidden_dim, hidden_dim, dropout_p) 26 | 27 | self.W_s = nn.Bilinear(hidden_dim, hidden_dim, output_dim) 28 | self.V_left = nn.Linear(hidden_dim, output_dim) 29 | self.V_right = nn.Linear(hidden_dim, output_dim) 30 | 31 | self.disable_feat = feat_embed_dim == 0 32 | if not self.disable_feat: 33 | self.W_feat = FeedForward(feat_embed_dim, 100, output_dim) 34 | 35 | def forward(self, h_ik, h_kj, feat=None): 36 | h_ik = self.W_left(h_ik) 37 | h_kj = self.W_right(h_kj) 38 | y = self.W_s(h_ik, h_kj) + self.V_left(h_ik) + self.V_right(h_kj) 39 | 40 | if not self.disable_feat: 41 | 
y_f = self.W_feat(feat) 42 | y = y + y_f 43 | 44 | return y 45 | -------------------------------------------------------------------------------- /src/models/classifier/shift_reduce_classifier_base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from data.batch import Batch 4 | from data.doc import Doc 5 | from models.classifier import ClassifierBase 6 | 7 | 8 | class ShiftReduceClassifierBase(ClassifierBase): 9 | def __init__(self, hidden_dim: int, dropout_p: float = 0.2, *args, **kwargs): 10 | super(ShiftReduceClassifierBase, self).__init__(*args, **kwargs) 11 | self.save_hyperparameters() 12 | self.hidden_dim = hidden_dim 13 | self.dropout_p = dropout_p 14 | self.org_embed = self.init_org_embeddings() 15 | 16 | @classmethod 17 | def params_from_config(cls, config): 18 | params = super().params_from_config(config) 19 | params.update( 20 | { 21 | "hidden_dim": config.hidden_dim, 22 | "dropout_p": config.dropout_p, 23 | } 24 | ) 25 | return params 26 | 27 | def init_org_embeddings(self): 28 | if self.disable_org_feat: 29 | return None 30 | 31 | num_feat = 0 32 | if not self.disable_org_sent: 33 | num_feat += 17 34 | if not self.disable_org_para: 35 | num_feat += 11 36 | 37 | return nn.Embedding(num_feat * 2, 10) 38 | 39 | def forward(self, doc: Doc, spans: dict, feats: dict): 40 | raise NotImplementedError 41 | 42 | def training_loss(self, batch: Batch): 43 | doc = batch.doc 44 | spans = batch.span 45 | feats = batch.feat 46 | output = self(doc, spans, feats) 47 | 48 | loss_dict = self.compute_loss(output, batch) 49 | return loss_dict 50 | 51 | def compute_loss(self, output, batch): 52 | raise NotImplementedError 53 | 54 | def predict(self, document_embedding, span: dict, feat: dict): 55 | raise NotImplementedError 56 | -------------------------------------------------------------------------------- /src/models/classifier/shift_reduce_classifier_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from data.batch import Batch 5 | from data.doc import Doc 6 | from models.classifier import ShiftReduceClassifierBase 7 | from models.classifier.linear import FeedForward 8 | 9 | 10 | class ShiftReduceClassifierV1(ShiftReduceClassifierBase): 11 | def __init__(self, *args, **kwargs): 12 | super(ShiftReduceClassifierV1, self).__init__(*args, **kwargs) 13 | self.save_hyperparameters() 14 | 15 | self.act_vocab = self.DATASET.action_vocab 16 | self.nuc_vocab = self.DATASET.nucleus_vocab 17 | self.rel_vocab = self.DATASET.relation_vocab 18 | 19 | embed_dim = self.encoder.get_embed_dim() * 3 20 | feat_embed_dim = self.get_org_embedding_dim() 21 | embed_dim += feat_embed_dim 22 | 23 | self.out_linear_action = FeedForward( 24 | embed_dim, self.hidden_dim, len(self.act_vocab), self.dropout_p 25 | ) 26 | self.out_linear_nucleus = FeedForward( 27 | embed_dim, self.hidden_dim, len(self.nuc_vocab), self.dropout_p 28 | ) 29 | self.out_linear_relation = FeedForward( 30 | embed_dim, self.hidden_dim, len(self.rel_vocab), self.dropout_p 31 | ) 32 | 33 | assert self.act_vocab[""] == self.nuc_vocab[""] == self.rel_vocab[""] 34 | pad_idx = self.act_vocab[""] 35 | self.pad_idx = pad_idx 36 | self.xent_loss = nn.CrossEntropyLoss(ignore_index=pad_idx) 37 | 38 | def forward(self, doc: Doc, spans: dict, feats: dict): 39 | document_embedding = self.encoder(doc) 40 | span_embeddings = [] 41 | for span, feat in zip(spans, feats): 42 | s1_emb = 
self.encoder.get_span_embedding(document_embedding, span["s1"]) 43 | s2_emb = self.encoder.get_span_embedding(document_embedding, span["s2"]) 44 | q1_emb = self.encoder.get_span_embedding(document_embedding, span["q1"]) 45 | embedding = torch.cat((s1_emb, s2_emb, q1_emb), dim=0) 46 | 47 | if not self.disable_org_feat: 48 | org_emb = self.org_embed(feat["org"]).view(-1) 49 | embedding = torch.cat((embedding, org_emb), dim=0) 50 | 51 | span_embeddings.append(embedding) 52 | 53 | span_embeddings = torch.stack(span_embeddings, dim=0) 54 | # predict label scores for act_nuc and rel 55 | act_scores = self.out_linear_action(span_embeddings) 56 | nuc_scores = self.out_linear_nucleus(span_embeddings) 57 | rel_scores = self.out_linear_relation(span_embeddings) 58 | 59 | output = { 60 | "act_scores": act_scores, 61 | "nuc_scores": nuc_scores, 62 | "rel_scores": rel_scores, 63 | } 64 | return output 65 | 66 | def compute_loss(self, output, batch: Batch): 67 | labels = batch.label 68 | act_idx = labels["act"] 69 | nuc_idx = labels["nuc"] 70 | rel_idx = labels["rel"] 71 | act_loss = self.xent_loss(output["act_scores"], act_idx) 72 | nuc_loss = self.xent_loss(output["nuc_scores"], nuc_idx) 73 | rel_loss = self.xent_loss(output["rel_scores"], rel_idx) 74 | if torch.all(nuc_idx == self.pad_idx): 75 | # if action is shift, there are no nuc and relation labels 76 | # and xent_loss return NaN. 77 | nuc_loss = torch.zeros_like(nuc_loss) 78 | rel_loss = torch.zeros_like(rel_loss) 79 | 80 | loss = (act_loss + nuc_loss + rel_loss) / 3 81 | 82 | return { 83 | "loss": loss, 84 | "act_loss": act_loss, 85 | "nuc_loss": nuc_loss, 86 | "rel_loss": rel_loss, 87 | } 88 | 89 | def predict(self, document_embedding, span: dict, feat: dict): 90 | s1_emb = self.encoder.get_span_embedding(document_embedding, span["s1"]) 91 | s2_emb = self.encoder.get_span_embedding(document_embedding, span["s2"]) 92 | q1_emb = self.encoder.get_span_embedding(document_embedding, span["q1"]) 93 | embedding = torch.cat((s1_emb, s2_emb, q1_emb), dim=0) 94 | if not self.disable_org_feat: 95 | org_emb = self.org_embed(feat["org"]).view(-1) 96 | embedding = torch.cat((embedding, org_emb), dim=0) 97 | 98 | act_scores = self.out_linear_action(embedding) 99 | nuc_scores = self.out_linear_nucleus(embedding) 100 | rel_scores = self.out_linear_relation(embedding) 101 | return act_scores, nuc_scores, rel_scores 102 | -------------------------------------------------------------------------------- /src/models/classifier/shift_reduce_classifier_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from data.batch import Batch 5 | from data.doc import Doc 6 | from models.classifier import ShiftReduceClassifierBase 7 | from models.classifier.linear import FeedForward 8 | 9 | 10 | class ShiftReduceClassifierV2(ShiftReduceClassifierBase): 11 | def __init__(self, *args, **kwargs): 12 | super(ShiftReduceClassifierV2, self).__init__(*args, **kwargs) 13 | self.save_hyperparameters() 14 | 15 | self.act_vocab = self.DATASET.action_vocab 16 | self.ful_vocab = self.DATASET.fully_label_vocab 17 | 18 | embed_dim = self.encoder.get_embed_dim() * 3 19 | feat_embed_dim = self.get_org_embedding_dim() 20 | embed_dim += feat_embed_dim 21 | 22 | self.out_linear_act = FeedForward( 23 | embed_dim, self.hidden_dim, len(self.act_vocab), self.dropout_p 24 | ) 25 | self.out_linear_label = FeedForward( 26 | embed_dim, self.hidden_dim, len(self.ful_vocab), self.dropout_p 27 | ) 28 | 29 | assert 
self.act_vocab[""] == self.ful_vocab[""]
30 |         pad_idx = self.act_vocab[""]
31 |         self.pad_idx = pad_idx
32 |         self.xent_loss = nn.CrossEntropyLoss(ignore_index=pad_idx)
33 | 
34 |     def forward(self, doc: Doc, spans: dict, feats: dict):
35 |         document_embedding = self.encoder(doc)
36 | 
37 |         span_embeddings = []
38 |         for span, feat in zip(spans, feats):
39 |             s1_emb = self.encoder.get_span_embedding(document_embedding, span["s1"])
40 |             s2_emb = self.encoder.get_span_embedding(document_embedding, span["s2"])
41 |             q1_emb = self.encoder.get_span_embedding(document_embedding, span["q1"])
42 |             embedding = torch.cat((s1_emb, s2_emb, q1_emb), dim=0)
43 | 
44 |             if not self.disable_org_feat:
45 |                 org_emb = self.org_embed(feat["org"]).view(-1)
46 |                 embedding = torch.cat((embedding, org_emb), dim=0)
47 | 
48 |             span_embeddings.append(embedding)
49 | 
50 |         span_embeddings = torch.stack(span_embeddings, dim=0)
51 | 
52 |         # predict scores for the action and the full (nucleus+relation) label
53 |         act_scores = self.out_linear_act(span_embeddings)
54 |         label_scores = self.out_linear_label(span_embeddings)
55 | 
56 |         output = {
57 |             "act_scores": act_scores,
58 |             "ful_scores": label_scores,
59 |         }
60 |         return output
61 | 
62 |     def compute_loss(self, output, batch: Batch):
63 |         labels = batch.label
64 |         act_idx = labels["act"]
65 |         ful_idx = labels["ful"]
66 |         act_loss = self.xent_loss(output["act_scores"], act_idx)
67 |         ful_loss = self.xent_loss(output["ful_scores"], ful_idx)
68 |         if torch.all(ful_idx == self.pad_idx):
69 |             # if action is shift, there are no nuc and relation labels
70 |             # and xent_loss return NaN.
71 |             ful_loss = torch.zeros_like(ful_loss)
72 | 
73 |         loss = (act_loss + ful_loss) / 2
74 | 
75 |         return {
76 |             "loss": loss,
77 |             "act_loss": act_loss,
78 |             "ful_loss": ful_loss,
79 |         }
80 | 
81 |     def predict(self, document_embedding, span: dict, feat: dict):
82 |         s1_emb = self.encoder.get_span_embedding(document_embedding, span["s1"])
83 |         s2_emb = self.encoder.get_span_embedding(document_embedding, span["s2"])
84 |         q1_emb = self.encoder.get_span_embedding(document_embedding, span["q1"])
85 |         embedding = torch.cat((s1_emb, s2_emb, q1_emb), dim=0)
86 |         if not self.disable_org_feat:
87 |             org_emb = self.org_embed(feat["org"]).view(-1)
88 |             embedding = torch.cat((embedding, org_emb), dim=0)
89 | 
90 |         act_scores = self.out_linear_act(embedding)
91 |         label_scores = self.out_linear_label(embedding)
92 |         return act_scores, label_scores
93 | 
--------------------------------------------------------------------------------
/src/models/classifier/shift_reduce_classifier_v3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from data.batch import Batch
5 | from data.doc import Doc
6 | from models.classifier import ShiftReduceClassifierBase
7 | from models.classifier.linear import FeedForward
8 | 
9 | 
10 | class ShiftReduceClassifierV3(ShiftReduceClassifierBase):
11 |     def __init__(self, disable_penalty: bool = False, *args, **kwargs):
12 |         super(ShiftReduceClassifierV3, self).__init__(*args, **kwargs)
13 |         self.save_hyperparameters()
14 | 
15 |         self.act_nuc_vocab = self.DATASET.act_nuc_vocab
16 |         self.rel_vocab = self.DATASET.relation_vocab
17 | 
18 |         embed_dim = self.encoder.get_embed_dim() * 3
19 |         feat_embed_dim = self.get_org_embedding_dim()
20 |         embed_dim += feat_embed_dim
21 | 
22 |         self.out_linear_act_nuc = FeedForward(
23 |             embed_dim, self.hidden_dim, len(self.act_nuc_vocab), self.dropout_p
24 |         )
25 |         self.out_linear_relation = FeedForward(
26 |             embed_dim, 
self.hidden_dim, len(self.rel_vocab), self.dropout_p 27 | ) 28 | 29 | self.disable_penalty = disable_penalty 30 | assert self.act_nuc_vocab[""] == self.rel_vocab[""] 31 | pad_idx = self.act_nuc_vocab[""] 32 | self.pad_idx = pad_idx 33 | 34 | self.act_nuc_weight = None 35 | if not self.disable_penalty: 36 | act2weight = { 37 | "shift_": 3 / 6, 38 | "reduce_nucleus-satellite": 1 / 6, 39 | "reduce_satellite-nucleus": 1 / 6, 40 | "reduce_nucleus-nucleus": 1 / 6, 41 | "": 0, 42 | } 43 | weight = torch.tensor([act2weight[act] for act in self.act_nuc_vocab.itos]) 44 | self.act_nuc_weight = nn.Parameter(weight, requires_grad=False) 45 | 46 | self.xent_act_nuc_loss = nn.CrossEntropyLoss(ignore_index=pad_idx, reduction="none") 47 | self.xent_relation_loss = nn.CrossEntropyLoss(ignore_index=pad_idx) 48 | 49 | @classmethod 50 | def params_from_config(cls, config): 51 | params = super().params_from_config(config) 52 | params.update( 53 | { 54 | "disable_penalty": config.disable_penalty, 55 | } 56 | ) 57 | return params 58 | 59 | def forward(self, doc: Doc, spans: dict, feats: dict): 60 | document_embedding = self.encoder(doc) 61 | 62 | span_embeddings = [] 63 | for span, feat in zip(spans, feats): 64 | s1_emb = self.encoder.get_span_embedding(document_embedding, span["s1"]) 65 | s2_emb = self.encoder.get_span_embedding(document_embedding, span["s2"]) 66 | q1_emb = self.encoder.get_span_embedding(document_embedding, span["q1"]) 67 | embedding = torch.cat((s1_emb, s2_emb, q1_emb), dim=0) 68 | 69 | if not self.disable_org_feat: 70 | org_emb = self.org_embed(feat["org"]).view(-1) 71 | embedding = torch.cat((embedding, org_emb), dim=0) 72 | 73 | span_embeddings.append(embedding) 74 | 75 | span_embeddings = torch.stack(span_embeddings, dim=0) 76 | 77 | # predict label scores for act_nuc and rel 78 | act_nuc_scores = self.out_linear_act_nuc(span_embeddings) 79 | rel_scores = self.out_linear_relation(span_embeddings) 80 | 81 | output = { 82 | "act_nuc_scores": act_nuc_scores, 83 | "rel_scores": rel_scores, 84 | } 85 | return output 86 | 87 | def compute_loss(self, output, batch: Batch): 88 | labels = batch.label 89 | act_nuc_idx = labels["act_nuc"] 90 | rel_idx = labels["rel"] 91 | act_nuc_scores = output["act_nuc_scores"] 92 | act_nuc_losses = self.xent_act_nuc_loss(act_nuc_scores, act_nuc_idx) 93 | 94 | # weighting for inbalance of shift-reduce action (Guz et al., 2020) 95 | if not self.disable_penalty: 96 | weight = self.act_nuc_weight[act_nuc_idx] 97 | act_nuc_losses = weight * act_nuc_losses 98 | 99 | act_nuc_loss = torch.mean(act_nuc_losses, dim=0) 100 | rel_loss = self.xent_relation_loss(output["rel_scores"], rel_idx) 101 | if torch.all(rel_idx == self.pad_idx): 102 | rel_loss = torch.zeros_like(rel_loss) 103 | 104 | loss = (act_nuc_loss + rel_loss) / 2 105 | 106 | return { 107 | "loss": loss, 108 | "act_nuc_loss": act_nuc_loss, 109 | "rel_loss": rel_loss, 110 | } 111 | 112 | def predict(self, document_embedding, span: dict, feat: dict): 113 | s1_emb = self.encoder.get_span_embedding(document_embedding, span["s1"]) 114 | s2_emb = self.encoder.get_span_embedding(document_embedding, span["s2"]) 115 | q1_emb = self.encoder.get_span_embedding(document_embedding, span["q1"]) 116 | embedding = torch.cat((s1_emb, s2_emb, q1_emb), dim=0) 117 | if not self.disable_org_feat: 118 | org_emb = self.org_embed(feat["org"]).view(-1) 119 | embedding = torch.cat((embedding, org_emb), dim=0) 120 | 121 | act_nuc_scores = self.out_linear_act_nuc(embedding) 122 | rel_scores = self.out_linear_relation(embedding) 123 | return 
act_nuc_scores, rel_scores 124 | 125 | def act_nuc_to_act_scores(self, act_nuc_scores: torch.Tensor): 126 | shift_idx = self.act_nuc_vocab["shift_"] 127 | reduce_idxs = [ 128 | self.act_nuc_vocab["reduce_{}".format(ns)] 129 | for ns in ["nucleus-satellite", "satellite-nucleus", "nucleus-nucleus"] 130 | ] 131 | 132 | # compute probability 133 | Z = sum( 134 | [act_nuc_scores[shift_idx].exp()] 135 | + [act_nuc_scores[reduce_idx].exp() for reduce_idx in reduce_idxs] 136 | ) 137 | shift_prob = act_nuc_scores[shift_idx].exp() / Z 138 | reduce_prob = 1 - shift_prob # (reduce_n-s, reduce_s-n, reduce_n-n) 139 | 140 | # compute log-likelihood for each action 141 | shift_log = shift_prob.clamp(min=1e-6).log() 142 | reduce_log = reduce_prob.clamp(min=1e-6).log() 143 | 144 | return {"shift": shift_log, "reduce": reduce_log} 145 | -------------------------------------------------------------------------------- /src/models/classifier/top_down_classifier_base.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from data.batch import Batch 7 | from data.doc import Doc 8 | from models.classifier import ClassifierBase 9 | from models.classifier.linear import DeepBiAffine 10 | 11 | 12 | class TopDownClassifierBase(ClassifierBase): 13 | def __init__( 14 | self, 15 | hidden_dim: int, 16 | dropout_p: float = 0.2, 17 | disable_penalty: bool = False, 18 | *args, 19 | **kwargs 20 | ): 21 | super(TopDownClassifierBase, self).__init__(*args, **kwargs) 22 | self.save_hyperparameters() 23 | self.hidden_dim = hidden_dim 24 | self.dropout_p = dropout_p 25 | self.disable_penalty = disable_penalty 26 | 27 | self.org_embed = self.init_org_embeddings() 28 | 29 | embed_dim = self.encoder.get_embed_dim() 30 | feat_embed_dim = self.get_org_embedding_dim() 31 | 32 | self.out_linear_split = DeepBiAffine( 33 | embed_dim, self.hidden_dim, 1, self.dropout_p, feat_embed_dim=feat_embed_dim 34 | ) 35 | 36 | self.xent_split_loss = nn.CrossEntropyLoss(reduction="none") 37 | 38 | @classmethod 39 | def params_from_config(cls, config): 40 | params = super().params_from_config(config) 41 | params.update( 42 | { 43 | "hidden_dim": config.hidden_dim, 44 | "dropout_p": config.dropout_p, 45 | "disable_penalty": config.disable_penalty, 46 | } 47 | ) 48 | return params 49 | 50 | def init_org_embeddings(self): 51 | if self.disable_org_feat: 52 | return None 53 | 54 | num_feat = 0 55 | if not self.disable_org_sent: 56 | num_feat += 10 57 | if not self.disable_org_para: 58 | num_feat += 6 59 | 60 | return nn.Embedding(num_feat * 2, 10) 61 | 62 | def forward(self, doc: Doc, spans: dict, feats: dict): 63 | raise NotImplementedError 64 | 65 | def training_loss(self, batch: Batch): 66 | doc = batch.doc 67 | spans = batch.span 68 | feats = batch.feat 69 | output = self(doc, spans, feats) 70 | 71 | loss_dict = self.compute_loss(output, batch) 72 | return loss_dict 73 | 74 | def compute_loss(self, output, batch: Batch): 75 | raise NotImplementedError 76 | 77 | def compute_split_loss(self, output, batch: Batch): 78 | labels = batch.label 79 | spans = batch.span 80 | 81 | spl_idx = labels["spl"] 82 | spl_losses = [ 83 | self.xent_split_loss(scores.unsqueeze(0), idx.unsqueeze(0)) 84 | for scores, idx in zip(output["spl_scores"], spl_idx) 85 | ] 86 | 87 | if not self.disable_penalty: 88 | # Segmentation loss with penalty (Koto et al., 2021) 89 | beta = 0.35 90 | spl_losses = [ 91 | (1 + (span["j"] - span["i"])) ** beta * loss 92 | for 
loss, span in zip(spl_losses, spans) 93 | ] 94 | 95 | spl_loss = torch.mean(torch.stack(spl_losses, dim=0)) 96 | return spl_loss 97 | 98 | def compute_label_loss(self, output, batch: Batch): 99 | raise NotImplementedError 100 | 101 | def predict_split( 102 | self, 103 | document_embedding, 104 | span: Tuple[int], 105 | doc: Doc, 106 | return_scores: bool = False, 107 | ): 108 | i, j = span 109 | left_embeddings, right_embeddings, org_indices = [], [], [] 110 | for k in range(i + 1, j): 111 | left_embeddings.append(self.encoder.get_span_embedding(document_embedding, (i, k))) 112 | right_embeddings.append(self.encoder.get_span_embedding(document_embedding, (k, j))) 113 | if not self.disable_org_feat: 114 | org_idx = self.parser.get_organization_features((i, k), (k, j), doc, self.device) 115 | org_indices.append(org_idx) 116 | 117 | left_embeddings = torch.stack(left_embeddings, dim=0) 118 | right_embeddings = torch.stack(right_embeddings, dim=0) 119 | 120 | org_embeddings = None 121 | if not self.disable_org_feat: 122 | n = len(org_indices) # num of split points 123 | org_indices = torch.stack(org_indices) 124 | org_embeddings = self.org_embed(org_indices).view(n, -1) 125 | 126 | split_scores = self.out_linear_split( 127 | left_embeddings, right_embeddings, org_embeddings 128 | ).squeeze(1) 129 | 130 | if return_scores: 131 | return split_scores 132 | 133 | k = torch.argmax(split_scores) + span[0] + 1 134 | return k.item() 135 | 136 | def predict_label( 137 | self, 138 | document_embedding, 139 | left_span: Tuple[int], 140 | right_span: Tuple[int], 141 | doc: Doc, 142 | ): 143 | raise NotImplementedError 144 | 145 | def predict_split_fast( 146 | self, 147 | document_embedding, 148 | span: Tuple[int], 149 | doc: Doc, 150 | span_to_org_indices: Dict, 151 | return_scores: bool = False, 152 | ): 153 | i, j = span 154 | left_spans = [] 155 | right_spans = [] 156 | for k in range(i + 1, j): 157 | left_spans.append((i, k)) 158 | right_spans.append((k, j)) 159 | 160 | left_embeddings = self.encoder.batch_get_span_embedding(document_embedding, left_spans) 161 | right_embeddings = self.encoder.batch_get_span_embedding(document_embedding, right_spans) 162 | 163 | if not self.disable_org_feat: 164 | org_indices = span_to_org_indices[span] 165 | 166 | org_embeddings = None 167 | if not self.disable_org_feat: 168 | n = len(org_indices) # num of split points 169 | org_indices = torch.stack(org_indices).to(self.device) 170 | org_embeddings = self.org_embed(org_indices).view(n, -1) 171 | 172 | split_scores = self.out_linear_split( 173 | left_embeddings, right_embeddings, org_embeddings 174 | ).squeeze(1) 175 | 176 | if return_scores: 177 | return split_scores 178 | 179 | k = torch.argmax(split_scores) + span[0] + 1 180 | return k.item() 181 | -------------------------------------------------------------------------------- /src/models/classifier/top_down_classifier_v1.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from data.batch import Batch 7 | from data.doc import Doc 8 | from models.classifier import TopDownClassifierBase 9 | from models.classifier.linear import DeepBiAffine 10 | 11 | 12 | class TopDownClassifierV1(TopDownClassifierBase): 13 | def __init__(self, *args, **kwargs): 14 | super(TopDownClassifierV1, self).__init__(*args, **kwargs) 15 | self.save_hyperparameters() 16 | 17 | self.nuc_vocab = self.DATASET.nucleus_vocab 18 | self.rel_vocab = self.DATASET.relation_vocab 19 | 20 
| embed_dim = self.encoder.get_embed_dim() 21 | feat_embed_dim = self.get_org_embedding_dim() 22 | 23 | self.out_linear_nucleus = DeepBiAffine( 24 | embed_dim, 25 | self.hidden_dim, 26 | len(self.nuc_vocab), 27 | self.dropout_p, 28 | feat_embed_dim, 29 | ) 30 | self.out_linear_relation = DeepBiAffine( 31 | embed_dim, 32 | self.hidden_dim, 33 | len(self.rel_vocab), 34 | self.dropout_p, 35 | feat_embed_dim, 36 | ) 37 | 38 | assert self.nuc_vocab[""] == self.rel_vocab[""] 39 | pad_idx = self.nuc_vocab[""] 40 | self.xent_label_loss = nn.CrossEntropyLoss(ignore_index=pad_idx) 41 | 42 | def forward(self, doc: Doc, spans: dict, feats: dict): 43 | document_embedding = self.encoder(doc) 44 | 45 | spl_scores = [] 46 | left_embeddings, right_embeddings, org_embeddings = [], [], [] 47 | for span in spans: 48 | i, j, k = span["i"], span["j"], span["k"] 49 | # predict split scores 50 | spl_scores.append( 51 | self.predict_split(document_embedding, (i, j), doc, return_scores=True) 52 | ) 53 | 54 | left_emb = self.encoder.get_span_embedding(document_embedding, (i, k)) 55 | right_emb = self.encoder.get_span_embedding(document_embedding, (k, j)) 56 | left_embeddings.append(left_emb) 57 | right_embeddings.append(right_emb) 58 | if not self.disable_org_feat: 59 | org_emb = self.org_embed( 60 | self.parser.get_organization_features((i, k), (k, j), doc, self.device) 61 | ).view(-1) 62 | org_embeddings.append(org_emb) 63 | 64 | left_embeddings = torch.stack(left_embeddings, dim=0) 65 | right_embeddings = torch.stack(right_embeddings, dim=0) 66 | org_embeddings = None if self.disable_org_feat else torch.stack(org_embeddings, dim=0) 67 | 68 | # predict label scores for nuc and rel 69 | nuc_scores = self.out_linear_nucleus(left_embeddings, right_embeddings, org_embeddings) 70 | rel_scores = self.out_linear_relation(left_embeddings, right_embeddings, org_embeddings) 71 | 72 | output = { 73 | "spl_scores": spl_scores, 74 | "nuc_scores": nuc_scores, 75 | "rel_scores": rel_scores, 76 | } 77 | return output 78 | 79 | def compute_loss(self, output, batch: Batch): 80 | spl_loss = self.compute_split_loss(output, batch) 81 | nuc_loss, rel_loss = self.compute_label_loss(output, batch) 82 | loss = (spl_loss + nuc_loss + rel_loss) / 3 83 | 84 | return { 85 | "loss": loss, 86 | "spl_loss": spl_loss, 87 | "nuc_loss": nuc_loss, 88 | "rel_loss": rel_loss, 89 | } 90 | 91 | def compute_label_loss(self, output, batch: Batch): 92 | labels = batch.label 93 | nuc_idx = labels["nuc"] 94 | rel_idx = labels["rel"] 95 | nuc_loss = self.xent_label_loss(output["nuc_scores"], nuc_idx) 96 | rel_loss = self.xent_label_loss(output["rel_scores"], rel_idx) 97 | return nuc_loss, rel_loss 98 | 99 | def predict_label( 100 | self, 101 | document_embedding, 102 | left_span: Tuple[int], 103 | right_span: Tuple[int], 104 | doc: Doc, 105 | return_scores: bool = False, 106 | ): 107 | left_emb = self.encoder.get_span_embedding(document_embedding, left_span) 108 | right_emb = self.encoder.get_span_embedding(document_embedding, right_span) 109 | org_emb = None 110 | if not self.disable_org_feat: 111 | org_emb = self.org_embed( 112 | self.parser.get_organization_features(left_span, right_span, doc, self.device) 113 | ).view(-1) 114 | 115 | nuc_scores = self.out_linear_nucleus(left_emb, right_emb, org_emb) 116 | rel_scores = self.out_linear_relation(left_emb, right_emb, org_emb) 117 | 118 | if return_scores: 119 | return nuc_scores, rel_scores 120 | 121 | nuc_scores[self.nuc_vocab[""]] = -float("inf") 122 | rel_scores[self.rel_vocab[""]] = -float("inf") 123 | 
nuc_label = self.nuc_vocab.lookup_token(torch.argmax(nuc_scores)) 124 | rel_label = self.rel_vocab.lookup_token(torch.argmax(rel_scores)) 125 | label = ":".join([nuc_label, rel_label]) 126 | return label 127 | -------------------------------------------------------------------------------- /src/models/classifier/top_down_classifier_v2.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from data.batch import Batch 7 | from data.doc import Doc 8 | from models.classifier import TopDownClassifierBase 9 | from models.classifier.linear import DeepBiAffine 10 | 11 | 12 | class TopDownClassifierV2(TopDownClassifierBase): 13 | def __init__(self, *args, **kwargs): 14 | super(TopDownClassifierV2, self).__init__(*args, **kwargs) 15 | self.save_hyperparameters() 16 | 17 | self.ful_vocab = self.DATASET.fully_label_vocab 18 | 19 | embed_dim = self.encoder.get_embed_dim() 20 | feat_embed_dim = self.get_org_embedding_dim() 21 | 22 | self.out_linear_label = DeepBiAffine( 23 | embed_dim, 24 | self.hidden_dim, 25 | len(self.ful_vocab), 26 | self.dropout_p, 27 | feat_embed_dim, 28 | ) 29 | 30 | pad_idx = self.ful_vocab[""] 31 | self.xent_label_loss = nn.CrossEntropyLoss(ignore_index=pad_idx) 32 | 33 | def forward(self, doc: Doc, spans: dict, feats: dict): 34 | document_embedding = self.encoder(doc) 35 | 36 | spl_scores = [] 37 | left_embeddings, right_embeddings, org_embeddings = [], [], [] 38 | for span in spans: 39 | i, j, k = span["i"], span["j"], span["k"] 40 | # predict split scores 41 | spl_scores.append( 42 | self.predict_split(document_embedding, (i, j), doc, return_scores=True) 43 | ) 44 | 45 | left_emb = self.encoder.get_span_embedding(document_embedding, (i, k)) 46 | right_emb = self.encoder.get_span_embedding(document_embedding, (k, j)) 47 | left_embeddings.append(left_emb) 48 | right_embeddings.append(right_emb) 49 | if not self.disable_org_feat: 50 | org_emb = self.org_embed( 51 | self.parser.get_organization_features((i, k), (k, j), doc, self.device) 52 | ).view(-1) 53 | org_embeddings.append(org_emb) 54 | 55 | left_embeddings = torch.stack(left_embeddings, dim=0) 56 | right_embeddings = torch.stack(right_embeddings, dim=0) 57 | org_embeddings = None if self.disable_org_feat else torch.stack(org_embeddings, dim=0) 58 | 59 | # predict label scores for nuc and rel 60 | label_scores = self.out_linear_label(left_embeddings, right_embeddings, org_embeddings) 61 | 62 | output = { 63 | "spl_scores": spl_scores, 64 | "ful_scores": label_scores, 65 | } 66 | return output 67 | 68 | def compute_loss(self, output, batch: Batch): 69 | spl_loss = self.compute_split_loss(output, batch) 70 | ful_loss = self.compute_label_loss(output, batch) 71 | loss = (spl_loss + ful_loss) / 2 72 | return { 73 | "loss": loss, 74 | "spl_loss": spl_loss, 75 | "full_loss": ful_loss, 76 | } 77 | 78 | def compute_label_loss(self, output, batch: Batch): 79 | labels = batch.label 80 | ful_idx = labels["ful"] 81 | ful_loss = self.xent_label_loss(output["ful_scores"], ful_idx) 82 | return ful_loss 83 | 84 | def predict_label( 85 | self, 86 | document_embedding, 87 | left_span: Tuple[int], 88 | right_span: Tuple[int], 89 | doc: Doc, 90 | return_scores: bool = False, 91 | ): 92 | left_emb = self.encoder.get_span_embedding(document_embedding, left_span) 93 | right_emb = self.encoder.get_span_embedding(document_embedding, right_span) 94 | org_emb = None 95 | if not self.disable_org_feat: 96 | org_emb = self.org_embed( 97 | 
self.parser.get_organization_features(left_span, right_span, doc, self.device) 98 | ).view(-1) 99 | 100 | ful_scores = self.out_linear_label(left_emb, right_emb, org_emb) 101 | if return_scores: 102 | return ful_scores 103 | 104 | ful_scores[self.ful_vocab[""]] = -float("inf") 105 | label = self.ful_vocab.lookup_token(torch.argmax(ful_scores)) 106 | return label 107 | -------------------------------------------------------------------------------- /src/models/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from models.encoder.encoder import Encoder 2 | from models.encoder.bert_encoder import BertEncoder 3 | 4 | 5 | class Encoders: 6 | encoder_dict = { 7 | "bert": BertEncoder, 8 | } 9 | 10 | @classmethod 11 | def from_config(cls, config): 12 | encoder_type = config.encoder_type 13 | encoder = cls.encoder_dict[encoder_type].from_config(config) 14 | return encoder 15 | 16 | 17 | __all__ = ["Encoder", "BertEncoder"] 18 | -------------------------------------------------------------------------------- /src/models/encoder/bert_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoConfig, AutoModel, AutoTokenizer 3 | 4 | from data.doc import Doc 5 | from models.encoder import Encoder 6 | 7 | 8 | def complete_model_name(model_name): 9 | if model_name.startswith("spanbert"): 10 | return "SpanBERT/" + model_name 11 | elif model_name.startswith("electra"): 12 | return "google/" + model_name 13 | elif model_name.startswith("mpnet"): 14 | return "microsoft/" + model_name 15 | elif model_name.startswith("deberta"): 16 | return "microsoft/" + model_name 17 | else: 18 | return model_name 19 | 20 | 21 | class BertEncoder(Encoder): 22 | def __init__(self, model_name: str, max_length: int = 512, stride: int = 30): 23 | super(BertEncoder, self).__init__() 24 | self.max_length = max_length 25 | self.stride = stride 26 | 27 | model_name = complete_model_name(model_name) 28 | self.config = AutoConfig.from_pretrained(model_name) 29 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) 30 | self.tokenizer.deprecation_warnings[ 31 | "sequence-length-is-longer-than-the-specified-maximum" 32 | ] = True 33 | self.model = AutoModel.from_pretrained(model_name) 34 | 35 | @classmethod 36 | def from_config(cls, config): 37 | params = { 38 | "model_name": config.bert_model_name, 39 | "max_length": config.bert_max_length, 40 | "stride": config.bert_stride, 41 | } 42 | return cls(**params) 43 | 44 | def get_embed_dim(self): 45 | return self.config.hidden_size 46 | 47 | def apply_tokenizer(self, doc: Doc): 48 | edu_strings = doc.get_edu_strings() 49 | 50 | raw_document = " ".join(edu_strings) 51 | inputs = self.tokenizer( 52 | raw_document, 53 | max_length=self.max_length, 54 | stride=self.stride, 55 | padding=True, 56 | truncation=True, 57 | return_overflowing_tokens=True, 58 | return_offsets_mapping=True, 59 | return_special_tokens_mask=True, 60 | return_token_type_ids=True, 61 | return_tensors="pt", 62 | ) 63 | 64 | flatten_inputs = self.tokenizer( 65 | raw_document, 66 | add_special_tokens=False, 67 | return_attention_mask=False, 68 | return_offsets_mapping=True, 69 | return_tensors="pt", 70 | ) 71 | 72 | reconst_edu_strings = [ 73 | self.tokenizer.decode(self.tokenizer.encode(edu, add_special_tokens=False)) 74 | for edu in edu_strings 75 | ] 76 | 77 | input_ids = flatten_inputs.input_ids[0] 78 | edu_offset = 0 79 | token_offset = 0 80 | buf = [] 81 | 
edu_to_subtokens_mappings = [] 82 | for token_id in input_ids: 83 | buf.append(token_id) 84 | edu_from_tokens = self.tokenizer.decode(buf).strip() 85 | edu = reconst_edu_strings[edu_offset] 86 | c_edu_from_tokens = edu_from_tokens.replace(" ", "").lower() 87 | c_edu = edu.replace(" ", "").lower() 88 | if c_edu_from_tokens == c_edu: # check for a character-level match 89 | edu_to_subtokens_mappings.append([token_offset, token_offset + len(buf)]) 90 | edu_offset += 1 91 | token_offset = token_offset + len(buf) 92 | buf = [] 93 | elif len(c_edu_from_tokens) > len(c_edu): 94 | raise ValueError('"{}" != "{}"'.format(edu_from_tokens, edu)) 95 | 96 | # check num of edus and mappings 97 | assert len(edu_to_subtokens_mappings) == len(edu_strings) 98 | inputs["edu_to_subtokens_mappings"] = torch.tensor( 99 | edu_to_subtokens_mappings, dtype=torch.long 100 | ) 101 | return inputs 102 | 103 | def forward(self, doc: Doc): 104 | if not hasattr(doc, "inputs") or doc.inputs is None: 105 | doc.inputs = self.apply_tokenizer(doc).to(self.model.device) 106 | 107 | inputs = doc.inputs 108 | 109 | # run bert model 110 | outputs = self.model( 111 | input_ids=inputs["input_ids"], 112 | token_type_ids=inputs["token_type_ids"], 113 | attention_mask=inputs["attention_mask"], 114 | ) 115 | 116 | # fix the effects of max_length and stride. 117 | input_ids = [] 118 | embeddings = [] 119 | for idx, (_input_ids, _embeddings, attention_mask, special_tokens_mask,) in enumerate( 120 | zip( 121 | inputs["input_ids"], 122 | outputs.last_hidden_state, 123 | inputs["attention_mask"], 124 | inputs["special_tokens_mask"], 125 | ) 126 | ): 127 | 128 | # first, trim special tokens (SEP, CLS) 129 | normal_token_indices = torch.where(special_tokens_mask == 0) 130 | _input_ids = _input_ids[normal_token_indices] 131 | _embeddings = _embeddings[normal_token_indices] 132 | if idx == 0: 133 | input_ids.append(_input_ids) 134 | embeddings.append(_embeddings) 135 | else: 136 | # second, trim strided (overlapping) tokens 137 | input_ids.append(_input_ids[self.stride :]) 138 | embeddings.append(_embeddings[self.stride :]) 139 | 140 | input_ids = torch.cat(input_ids, dim=0) 141 | embeddings = torch.cat(embeddings, dim=0) 142 | 143 | bert_output = { 144 | "input_ids": input_ids, 145 | "embeddings": embeddings, 146 | "edu_to_subtokens_mappings": inputs.edu_to_subtokens_mappings, 147 | # 'edu_strings': doc.get_edu_strings(), 148 | # 'subtokens': self.tokenizer.convert_ids_to_tokens(input_ids), 149 | } 150 | # self.check_mapping(bert_output, self.tokenizer) 151 | return bert_output 152 | 153 | def get_span_embedding(self, bert_output, edu_span): 154 | edu_to_subtokens_mappings = bert_output["edu_to_subtokens_mappings"] 155 | subtoken_embeddings = bert_output["embeddings"] 156 | 157 | if edu_span == (-1, -1): 158 | return torch.zeros(self.get_embed_dim(), device=self.model.device) 159 | 160 | i = edu_to_subtokens_mappings[edu_span[0]][0] 161 | j = edu_to_subtokens_mappings[edu_span[1] - 1][1] 162 | embedding = (subtoken_embeddings[i] + subtoken_embeddings[j - 1]) / 2 163 | return embedding 164 | 165 | def batch_get_span_embedding(self, bert_output, edu_spans): 166 | edu_to_subtokens_mappings = bert_output["edu_to_subtokens_mappings"] 167 | subtoken_embeddings = bert_output["embeddings"] 168 | 169 | l, r = [], [] 170 | for span in edu_spans: 171 | l.append(span[0]) 172 | r.append(span[1] - 1) 173 | 174 | i = edu_to_subtokens_mappings[l][:, 0] 175 | j = edu_to_subtokens_mappings[r][:, 1] 176 | embedding = (subtoken_embeddings[i] + subtoken_embeddings[j
- 1]) / 2 177 | return embedding 178 | 179 | @staticmethod 180 | def check_mapping(bert_output, tokenizer=None): 181 | edus = bert_output["edu_strings"] 182 | subtokens = bert_output["subtokens"] 183 | edu_to_subtokens_mappings = bert_output["edu_to_subtokens_mappings"] 184 | 185 | for edu_idx, edu_string in enumerate(edus): 186 | print("-" * 20) 187 | print(edu_string) 188 | subtoken_span = edu_to_subtokens_mappings[edu_idx] 189 | subtokens_in_edu = [] 190 | for j, subtoken_idx in enumerate(range(*subtoken_span)): 191 | subtoken = subtokens[subtoken_idx] 192 | print(subtoken, end=" ") 193 | subtokens_in_edu.append(subtoken) 194 | 195 | print() 196 | 197 | if tokenizer: 198 | string = tokenizer.convert_tokens_to_string(subtokens_in_edu) 199 | print(string) 200 | assert string.strip() == edu_string 201 | 202 | return 203 | -------------------------------------------------------------------------------- /src/models/encoder/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from data.doc import Doc 4 | 5 | 6 | class Encoder(nn.Module): 7 | def __init__(self): 8 | super(Encoder, self).__init__() 9 | 10 | @classmethod 11 | def from_config(cls, config): 12 | raise NotImplementedError 13 | 14 | def forward(self, doc: Doc): 15 | raise NotImplementedError 16 | 17 | def get_embed_dim(self): 18 | raise NotImplementedError 19 | 20 | def get_span_embedding(self, encoder_output, span): 21 | raise NotImplementedError 22 | -------------------------------------------------------------------------------- /src/models/parser/__init__.py: -------------------------------------------------------------------------------- 1 | from models.parser.parser_base import ParserBase 2 | from models.parser.shift_reduce_parser_base import ShiftReduceParserBase 3 | from models.parser.shift_reduce_parser_v1 import ShiftReduceParserV1 4 | from models.parser.shift_reduce_parser_v2 import ShiftReduceParserV2 5 | from models.parser.shift_reduce_parser_v3 import ShiftReduceParserV3 6 | from models.parser.top_down_parser_base import TopDownParserBase 7 | from models.parser.top_down_parser_v1 import TopDownParserV1 8 | from models.parser.top_down_parser_v2 import TopDownParserV2 9 | 10 | 11 | class Parsers: 12 | parser_dict = { 13 | "top_down_v1": TopDownParserV1, 14 | "top_down_v2": TopDownParserV2, 15 | "shift_reduce_v1": ShiftReduceParserV1, 16 | "shift_reduce_v2": ShiftReduceParserV2, 17 | "shift_reduce_v3": ShiftReduceParserV3, 18 | } 19 | 20 | @classmethod 21 | def from_config(cls, config, classifier): 22 | parser_type = config.model_type 23 | parser = cls.parser_dict[parser_type](classifier) 24 | return parser 25 | 26 | 27 | __all__ = [ 28 | "ParserBase", 29 | "TopDownParserBase", 30 | "TopDownParserV1", 31 | "TopDownParserV2", 32 | "ShiftReduceParserBase", 33 | "ShiftReduceParserV1", 34 | "ShiftReduceParserV2", 35 | "ShiftReduceParserV3", 36 | ] 37 | -------------------------------------------------------------------------------- /src/models/parser/organization_feature.py: -------------------------------------------------------------------------------- 1 | class OrganizationFeature: 2 | @staticmethod 3 | def IsSameSent(edus_a, edus_b): 4 | if edus_a == [] or edus_b == []: 5 | return False 6 | return edus_a[0].sent_idx == edus_b[-1].sent_idx 7 | 8 | @staticmethod 9 | def IsContinueSent(edus_a, edus_b): 10 | if edus_a == [] or edus_b == []: 11 | return False 12 | return edus_a[-1].sent_idx == edus_b[0].sent_idx 13 | 14 | @staticmethod 15 | def 
IsSamePara(edus_a, edus_b): 16 | if edus_a == [] or edus_b == []: 17 | return False 18 | return edus_a[0].para_idx == edus_b[-1].para_idx 19 | 20 | @staticmethod 21 | def IsContinuePara(edus_a, edus_b): 22 | if edus_a == [] or edus_b == []: 23 | return False 24 | return edus_a[-1].para_idx == edus_b[0].para_idx 25 | 26 | @staticmethod 27 | def IsStartSent(edus): 28 | if edus == []: 29 | return False 30 | return edus[0].is_start_sent 31 | 32 | @staticmethod 33 | def IsStartPara(edus): 34 | if edus == []: 35 | return False 36 | return edus[0].is_start_para 37 | 38 | @staticmethod 39 | def IsStartDoc(edus): 40 | if edus == []: 41 | return False 42 | return edus[0].is_start_doc 43 | 44 | @staticmethod 45 | def IsEndSent(edus): 46 | if edus == []: 47 | return False 48 | return edus[-1].is_end_sent 49 | 50 | @staticmethod 51 | def IsEndPara(edus): 52 | if edus == []: 53 | return False 54 | return edus[-1].is_end_para 55 | 56 | @staticmethod 57 | def IsEndDoc(edus): 58 | if edus == []: 59 | return False 60 | return edus[-1].is_end_doc 61 | -------------------------------------------------------------------------------- /src/models/parser/parser_base.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Union 3 | 4 | from data.dataset import Dataset 5 | from data.doc import Doc 6 | from data.tree import AttachTree, RSTTree 7 | from metrics import OriginalParseval 8 | from models.classifier import ClassifierBase 9 | 10 | 11 | class ParserBase(object): 12 | def __init__(self, classifier: ClassifierBase): 13 | self.classifier = classifier 14 | 15 | def parse_dataset( 16 | self, 17 | dataset, 18 | verbose: bool = False, 19 | ): 20 | if verbose: 21 | print("start parsing a dataset") 22 | ss = time.time() 23 | 24 | result = {"doc_id": [], "pred_tree": [], "gold_tree": []} 25 | for batch in dataset: 26 | assert batch.unit_type == "document" 27 | doc = batch.doc 28 | tree = self.parse(doc) 29 | 30 | result["doc_id"].append(doc.doc_id) 31 | result["pred_tree"].append(tree) 32 | result["gold_tree"].append(doc.tree) 33 | 34 | if verbose: 35 | print("elapsed time for all: {:.2f} [sec]".format(time.time() - ss)) 36 | 37 | return result 38 | 39 | def parse(self, doc: Doc): 40 | raise NotImplementedError 41 | 42 | def parse_dataset_with_naked_tree( 43 | self, 44 | dataset, 45 | verbose: bool = False, 46 | ): 47 | if verbose: 48 | print("start parsing a dataset with naked tree") 49 | ss = time.time() 50 | 51 | result = {"doc_id": [], "pred_tree": [], "gold_tree": []} 52 | for batch in dataset: 53 | assert batch.unit_type == "document" 54 | doc = batch.doc 55 | tree = self.parse_with_naked_tree(doc, doc.tree) 56 | 57 | result["doc_id"].append(doc.doc_id) 58 | result["pred_tree"].append(tree) 59 | result["gold_tree"].append(doc.tree) 60 | 61 | if verbose: 62 | print("elapsed time for all: {:.2f} [sec]".format(time.time() - ss)) 63 | 64 | return result 65 | 66 | def parse_with_naked_tree( 67 | self, 68 | doc: Doc, 69 | tree: Union[RSTTree, AttachTree], 70 | ): 71 | raise NotImplementedError 72 | 73 | def parse_dataset_topk(self, dataset, topk: int, verbose: bool = False): 74 | if verbose: 75 | print("start parsing a dataset with top-k (k={})".format(topk)) 76 | ss = time.time() 77 | 78 | result = {"doc_id": [], "pred_tree": [], "pred_trees": [], "gold_tree": []} 79 | metric = OriginalParseval() 80 | for batch in dataset: 81 | assert batch.unit_type == "document" 82 | doc = batch.doc 83 | # if verbose: 84 | # print('document id: ', doc.doc_id) 85 | 
# print('- # of edus : {}'.format(len(doc.edus))) 86 | # s = time.time() 87 | 88 | trees = self.parse_topk(doc, topk) 89 | 90 | best_tree, best_score = None, -1 91 | for tree in trees: 92 | metric.update([tree], [doc.tree]) 93 | scores = metric.compute() 94 | score = scores["OriginalParseval-F"].item() 95 | if score > best_score: 96 | best_tree = tree 97 | best_score = score 98 | 99 | metric.reset() 100 | 101 | # if verbose: 102 | # print('- best score : {:.2f}'.format(best_score)) 103 | # print('- elapsed time: {:.2f} [sec]'.format(time.time() - s)) 104 | 105 | result["doc_id"].append(doc.doc_id) 106 | result["pred_tree"].append(best_tree) 107 | result["pred_trees"].append(trees) 108 | result["gold_tree"].append(doc.tree) 109 | 110 | if verbose: 111 | print("elapsed time for all: {:.2f} [sec]".format(time.time() - ss)) 112 | 113 | return result 114 | 115 | def parse_topk(self, doc: Doc): 116 | raise NotImplementedError 117 | 118 | def generate_training_samples(cls, dataset: Dataset, level: str): 119 | raise NotImplementedError 120 | 121 | def get_organization_features(self): 122 | raise NotImplementedError 123 | -------------------------------------------------------------------------------- /src/models/parser/shift_reduce_parser_base.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | from data.doc import Doc 6 | from data.tree import AttachTree 7 | from models.classifier import ShiftReduceClassifierBase 8 | from models.parser import ParserBase 9 | from models.parser.organization_feature import OrganizationFeature as OrgFeat 10 | from models.parser.shift_reduce_state import ShiftReduceState 11 | 12 | 13 | class ShiftReduceParserBase(ParserBase): 14 | """Basic shift-reduce parser.""" 15 | 16 | def __init__(self, classifier: ShiftReduceClassifierBase): 17 | super(ShiftReduceParserBase, self).__init__(classifier) 18 | assert isinstance(classifier, ShiftReduceClassifierBase) 19 | 20 | def select_action_and_labels(self, bert_output, span, feat, state, gold_act=None): 21 | raise NotImplementedError 22 | 23 | def parse(self, doc: Doc): 24 | bert_output = self.classifier.encoder(doc) 25 | n_edus = len(doc.edus) 26 | 27 | state = ShiftReduceState(n_edus) 28 | while not state.is_end(): 29 | s1, s2, q1 = state.get_state() 30 | span = {"s1": s1, "s2": s2, "q1": q1} 31 | feat = {"org": self.get_organization_features(s1, s2, q1, doc, self.classifier.device)} 32 | 33 | # predict action and labels 34 | act, nuc, rel = self.select_action_and_labels(bert_output, span, feat, state) 35 | 36 | # update stack and queue 37 | state.operate(act, nuc, rel) 38 | 39 | tree = state.get_tree() 40 | return tree 41 | 42 | def parse_topk(self, doc: Doc, k: int): 43 | return self.BEAM(doc, k) 44 | 45 | def BEAM(self, doc: Doc, top_k: int): 46 | raise NotImplementedError 47 | 48 | def parse_with_naked_tree(self, doc: Doc, naked_tree: AttachTree): 49 | return self.labeling_to_naked_tree(doc, naked_tree) 50 | 51 | def labeling_to_naked_tree(self, doc: Doc, tree: AttachTree): 52 | bert_output = self.classifier.encoder(doc) 53 | n_edus = len(doc.edus) 54 | 55 | act_list, _, _ = self.generate_action_sequence(tree) 56 | 57 | state = ShiftReduceState(n_edus) 58 | for gold_act in act_list: 59 | s1, s2, q1 = state.get_state() 60 | span = {"s1": s1, "s2": s2, "q1": q1} 61 | feat = {"org": self.get_organization_features(s1, s2, q1, doc, self.classifier.device)} 62 | 63 | # predict action and labels 64 | act, nuc, rel = 
self.select_action_and_labels( 65 | bert_output, span, feat, state, gold_act=gold_act 66 | ) 67 | 68 | # update stack and queue 69 | state.operate(gold_act, nuc, rel) 70 | 71 | tree = state.get_tree() 72 | return tree 73 | 74 | def generate_action_sequence(self, tree: AttachTree): 75 | act_list, nuc_list, rel_list = [], [], [] 76 | for tp in tree.treepositions("postorder"): 77 | node = tree[tp] 78 | if not isinstance(node, AttachTree): 79 | continue 80 | 81 | label = node.label() 82 | 83 | if len(node) == 1 and label == "text": 84 | # terminal node 85 | act_list.append("shift") 86 | nuc_list.append("") 87 | rel_list.append("") 88 | elif len(node) == 2: 89 | # non-terminal node 90 | nuc, rel = node.label().split(":", maxsplit=1) 91 | act_list.append("reduce") 92 | nuc_list.append(nuc) 93 | rel_list.append(rel) 94 | else: 95 | raise ValueError("Input tree is not binarized.") 96 | 97 | return act_list, nuc_list, rel_list 98 | 99 | def get_organization_features( 100 | self, s1: Tuple[int], s2: Tuple[int], q1: Tuple[int], doc: Doc, device=None 101 | ): 102 | # span == (-1, -1) -> edus = [] 103 | edus = doc.edus 104 | s1_edus = edus[slice(*s1)] 105 | s2_edus = edus[slice(*s2)] 106 | q1_edus = edus[slice(*q1)] 107 | 108 | # init features 109 | features = [] 110 | 111 | if not self.classifier.disable_org_sent: 112 | # for Stack 1 and Stack2 113 | features.append(OrgFeat.IsSameSent(s2_edus, s1_edus)) 114 | features.append(OrgFeat.IsContinueSent(s2_edus, s1_edus)) 115 | 116 | # for Stack 1 and Queue 1 117 | features.append(OrgFeat.IsSameSent(s1_edus, q1_edus)) 118 | features.append(OrgFeat.IsContinueSent(s1_edus, q1_edus)) 119 | 120 | # for Stack 1, 2 and Queue 1 121 | features.append( 122 | OrgFeat.IsSameSent(s2_edus, s1_edus) & OrgFeat.IsSameSent(s1_edus, q1_edus) 123 | ) 124 | 125 | # starts and ends a sentence 126 | features.append(OrgFeat.IsStartSent(s1_edus)) 127 | features.append(OrgFeat.IsStartSent(s2_edus)) 128 | features.append(OrgFeat.IsStartSent(q1_edus)) 129 | features.append(OrgFeat.IsEndSent(s1_edus)) 130 | features.append(OrgFeat.IsEndSent(s2_edus)) 131 | features.append(OrgFeat.IsEndSent(q1_edus)) 132 | 133 | # starts and ends a document 134 | features.append(OrgFeat.IsStartDoc(s1_edus)) 135 | features.append(OrgFeat.IsStartDoc(s2_edus)) 136 | features.append(OrgFeat.IsStartDoc(q1_edus)) 137 | features.append(OrgFeat.IsEndDoc(s1_edus)) 138 | features.append(OrgFeat.IsEndDoc(s2_edus)) 139 | features.append(OrgFeat.IsEndDoc(q1_edus)) 140 | 141 | if not self.classifier.disable_org_para: 142 | # for Stack 1 and Stack2 143 | features.append(OrgFeat.IsSamePara(s2_edus, s1_edus)) 144 | features.append(OrgFeat.IsContinuePara(s2_edus, s1_edus)) 145 | 146 | # for Stack 1 and Queue 1 147 | features.append(OrgFeat.IsSamePara(s1_edus, q1_edus)) 148 | features.append(OrgFeat.IsContinuePara(s1_edus, q1_edus)) 149 | 150 | # for Stack 1, 2 and Queue 1 151 | features.append( 152 | OrgFeat.IsSamePara(s2_edus, s1_edus) & OrgFeat.IsSamePara(s1_edus, q1_edus) 153 | ) 154 | 155 | # starts and ends a paragraph 156 | features.append(OrgFeat.IsStartPara(s1_edus)) 157 | features.append(OrgFeat.IsStartPara(s2_edus)) 158 | features.append(OrgFeat.IsStartPara(q1_edus)) 159 | features.append(OrgFeat.IsEndPara(s1_edus)) 160 | features.append(OrgFeat.IsEndPara(s2_edus)) 161 | features.append(OrgFeat.IsEndPara(q1_edus)) 162 | 163 | # convert to index 164 | bias = torch.tensor([2 * i for i in range(len(features))], dtype=torch.long, device=device) 165 | features = torch.tensor(features, dtype=torch.long, 
device=device) 166 | return bias + features 167 | -------------------------------------------------------------------------------- /src/models/parser/shift_reduce_parser_v1.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | from data.dataset import Dataset 7 | from data.doc import Doc 8 | from data.tree import RSTTree 9 | from models.classifier import ShiftReduceClassifierBase 10 | from models.parser import ShiftReduceParserBase 11 | from models.parser.shift_reduce_state import ShiftReduceState 12 | from models.parser.utils import batch_iter 13 | 14 | 15 | class ShiftReduceParserV1(ShiftReduceParserBase): 16 | """This basic shift-reduce parser predicts action, nuc, and relation by independent 17 | classifiers.""" 18 | 19 | def __init__(self, classifier: ShiftReduceClassifierBase): 20 | super(ShiftReduceParserV1, self).__init__(classifier) 21 | assert isinstance(classifier, ShiftReduceClassifierBase) 22 | 23 | def generate_training_samples( 24 | self, 25 | dataset: Dataset, 26 | unit_type: str, 27 | batch_size: Optional[int] = None, 28 | ): 29 | action_vocab = dataset.action_vocab 30 | nucleus_vocab = dataset.nucleus_vocab 31 | relation_vocab = dataset.relation_vocab 32 | 33 | samples = [] 34 | for doc in dataset: 35 | tree = doc.tree 36 | if isinstance(tree, RSTTree): 37 | tree = RSTTree.convert_to_attach(tree) 38 | 39 | act_list, nuc_list, rel_list = self.generate_action_sequence(tree) 40 | xs, ys, fs = [], [], [] 41 | state = ShiftReduceState(len(tree.leaves())) 42 | for act, nuc, rel in zip(act_list, nuc_list, rel_list): 43 | s1, s2, q1 = state.get_state() 44 | act_idx = action_vocab[act] 45 | nuc_idx = nucleus_vocab[nuc] 46 | rel_idx = relation_vocab[rel] 47 | org_feat = self.get_organization_features(s1, s2, q1, doc) 48 | xs.append({"s1": s1, "s2": s2, "q1": q1}) 49 | ys.append({"act": act_idx, "nuc": nuc_idx, "rel": rel_idx}) 50 | fs.append({"org": org_feat}) 51 | state.operate(act, nuc, rel) 52 | 53 | assert tree == state.get_tree() 54 | 55 | if unit_type == "document": 56 | samples.append({"doc": doc, "span": xs, "label": ys, "feat": fs}) 57 | elif unit_type == "span": 58 | for x, y, f in zip(xs, ys, fs): 59 | samples.append({"doc": doc, "span": x, "label": y, "feat": f}) 60 | elif unit_type == "span_fast": 61 | assert batch_size > 1 62 | # should use Trainer.reload_dataloaders_every_n_epochs=1 63 | indices = list(range(len(xs))) 64 | random.shuffle(indices) 65 | xs = [xs[i] for i in indices] 66 | ys = [ys[i] for i in indices] 67 | fs = [fs[i] for i in indices] 68 | for feats in batch_iter(list(zip(xs, ys, fs)), batch_size): 69 | b_xs, b_ys, b_fs = list(zip(*feats)) 70 | samples.append({"doc": doc, "span": b_xs, "label": b_ys, "feat": b_fs}) 71 | else: 72 | raise ValueError("Invalid batch unit_type ({})".format(unit_type)) 73 | 74 | return samples 75 | 76 | def select_action_and_labels(self, bert_output, span, feat, state, gold_act=None): 77 | act_vocab = self.classifier.act_vocab 78 | nuc_vocab = self.classifier.nuc_vocab 79 | rel_vocab = self.classifier.rel_vocab 80 | 81 | act_scores, nuc_scores, rel_scores = self.classifier.predict(bert_output, span, feat) 82 | 83 | # select allowed action 84 | _, act_indices = torch.sort(act_scores, dim=0, descending=True) 85 | for act_idx in act_indices: 86 | act = act_vocab.lookup_token(act_idx) 87 | if act == "": 88 | continue 89 | 90 | if state.is_allowed_action(act): 91 | break 92 | 93 | # use gold_act if gold_act is given 94 | 
if gold_act is not None: 95 | act = gold_act 96 | 97 | nuc, rel = None, None 98 | if act != "shift": 99 | nuc_scores[nuc_vocab[""]] = -float("inf") 100 | nuc = nuc_vocab.lookup_token(torch.argmax(nuc_scores)) 101 | rel_scores[rel_vocab[""]] = -float("inf") 102 | rel = rel_vocab.lookup_token(torch.argmax(rel_scores)) 103 | 104 | return act, nuc, rel 105 | 106 | def BEAM(self, doc: Doc, top_k: int): 107 | act_vocab = self.classifier.act_vocab 108 | nuc_vocab = self.classifier.nuc_vocab 109 | rel_vocab = self.classifier.rel_vocab 110 | 111 | bert_output = self.classifier.encoder(doc) 112 | n_edus = len(doc.edus) 113 | 114 | num_steps = 2 * n_edus - 1 115 | beams = [None] * (num_steps + 1) 116 | beams[0] = [ShiftReduceState(n_edus)] # initial state 117 | 118 | for i in range(num_steps): 119 | buf = [] 120 | for old_state in beams[i]: 121 | s1, s2, q1 = old_state.get_state() 122 | span = {"s1": s1, "s2": s2, "q1": q1} 123 | feat = { 124 | "org": self.get_organization_features(s1, s2, q1, doc, self.classifier.device) 125 | } 126 | act_scores, nuc_scores, rel_scores = self.classifier.predict( 127 | bert_output, span, feat 128 | ) 129 | log_act_scores = act_scores.log_softmax(dim=0) 130 | 131 | for act in old_state.allowed_actions(): 132 | rel, nuc = None, None 133 | if act != "shift": 134 | nuc_scores[nuc_vocab[""]] = -float("inf") 135 | nuc = nuc_vocab.lookup_token(torch.argmax(nuc_scores)) 136 | rel_scores[rel_vocab[""]] = -float("inf") 137 | rel = rel_vocab.lookup_token(torch.argmax(rel_scores)) 138 | 139 | act_idx = act_vocab[act] 140 | action_score = log_act_scores[act_idx].item() 141 | new_state = old_state.copy() 142 | new_state.operate(act, nuc, rel, score=action_score) 143 | buf.append(new_state) 144 | 145 | buf = sorted(buf, key=lambda x: x.score, reverse=True) # descending order 146 | 147 | tmp = {} 148 | beams[i + 1] = [] 149 | for j, new_state in enumerate(buf): 150 | _state = new_state.get_state() 151 | score = new_state.score 152 | idx = (_state, score) 153 | if idx not in tmp: 154 | tmp[idx] = new_state 155 | beams[i + 1].append(new_state) 156 | else: 157 | # print('duplicate') 158 | # tmp[_state].merge_state(new_state) 159 | pass 160 | 161 | if len(tmp) == top_k: 162 | break 163 | 164 | trees = [x.get_tree() for x in beams[-1][:top_k]] 165 | return trees 166 | -------------------------------------------------------------------------------- /src/models/parser/shift_reduce_parser_v2.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | from data.dataset import Dataset 7 | from data.doc import Doc 8 | from data.tree import RSTTree 9 | from models.classifier import ShiftReduceClassifierBase 10 | from models.parser import ShiftReduceParserBase 11 | from models.parser.shift_reduce_state import ShiftReduceState 12 | from models.parser.utils import batch_iter 13 | 14 | 15 | class ShiftReduceParserV2(ShiftReduceParserBase): 16 | """This shift-reduce parser predicts nuc and relation labels by one classifier.""" 17 | 18 | def __init__(self, classifier: ShiftReduceClassifierBase): 19 | super(ShiftReduceParserV2, self).__init__(classifier) 20 | assert isinstance(classifier, ShiftReduceClassifierBase) 21 | 22 | def generate_training_samples( 23 | self, 24 | dataset: Dataset, 25 | unit_type: str, 26 | batch_size: Optional[int] = None, 27 | ): 28 | act_vocab = self.classifier.act_vocab 29 | ful_vocab = self.classifier.ful_vocab 30 | 31 | samples = [] 32 | for doc in dataset: 33 | tree = 
doc.tree 34 | if isinstance(tree, RSTTree): 35 | tree = RSTTree.convert_to_attach(tree) 36 | 37 | act_list, nuc_list, rel_list = self.generate_action_sequence(tree) 38 | xs, ys, fs = [], [], [] 39 | state = ShiftReduceState(len(tree.leaves())) 40 | for act, nuc, rel in zip(act_list, nuc_list, rel_list): 41 | s1, s2, q1 = state.get_state() 42 | label = "" if nuc == rel == "" else ":".join([nuc, rel]) 43 | act_idx = act_vocab[act] 44 | ful_idx = ful_vocab[label] 45 | org_feat = self.get_organization_features(s1, s2, q1, doc) 46 | xs.append({"s1": s1, "s2": s2, "q1": q1}) 47 | ys.append({"act": act_idx, "ful": ful_idx}) 48 | fs.append({"org": org_feat}) 49 | state.operate(act, nuc, rel) 50 | 51 | assert tree == state.get_tree() 52 | 53 | if unit_type == "document": 54 | samples.append({"doc": doc, "span": xs, "label": ys, "feat": fs}) 55 | elif unit_type == "span": 56 | for x, y, f in zip(xs, ys, fs): 57 | samples.append({"doc": doc, "span": x, "label": y, "feat": f}) 58 | elif unit_type == "span_fast": 59 | assert batch_size > 1 60 | # should use Trainer.reload_dataloaders_every_n_epochs=1 61 | indices = list(range(len(xs))) 62 | random.shuffle(indices) 63 | xs = [xs[i] for i in indices] 64 | ys = [ys[i] for i in indices] 65 | fs = [fs[i] for i in indices] 66 | for feats in batch_iter(list(zip(xs, ys, fs)), batch_size): 67 | b_xs, b_ys, b_fs = list(zip(*feats)) 68 | samples.append({"doc": doc, "span": b_xs, "label": b_ys, "feat": b_fs}) 69 | else: 70 | raise ValueError("Invalid batch unit_type ({})".format(unit_type)) 71 | 72 | return samples 73 | 74 | def select_action_and_labels(self, bert_output, span, feat, state, gold_act=None): 75 | act_vocab = self.classifier.act_vocab 76 | ful_vocab = self.classifier.ful_vocab 77 | 78 | act_scores, ful_scores = self.classifier.predict(bert_output, span, feat) 79 | 80 | # select allowed action 81 | _, act_indices = torch.sort(act_scores, dim=0, descending=True) 82 | for act_idx in act_indices: 83 | act = act_vocab.lookup_token(act_idx) 84 | if act == "": 85 | continue 86 | 87 | if gold_act is not None: 88 | if act == gold_act: 89 | break 90 | 91 | else: 92 | if state.is_allowed_action(act): 93 | break 94 | 95 | rel, nuc = None, None 96 | if act != "shift": 97 | ful_scores[ful_vocab[""]] = -float("inf") 98 | label = ful_vocab.lookup_token(torch.argmax(ful_scores)) 99 | nuc, rel = label.split(":", maxsplit=1) 100 | 101 | return act, nuc, rel 102 | 103 | def parse_topk(self, doc: Doc, k: int): 104 | return self.BEAM(doc, k) 105 | 106 | def BEAM(self, doc: Doc, top_k: int): 107 | act_vocab = self.classifier.act_vocab 108 | ful_vocab = self.classifier.ful_vocab 109 | 110 | bert_output = self.classifier.encoder(doc) 111 | n_edus = len(doc.edus) 112 | 113 | num_steps = 2 * n_edus - 1 114 | beams = [None] * (num_steps + 1) 115 | beams[0] = [ShiftReduceState(n_edus)] # initial state 116 | 117 | for i in range(num_steps): 118 | buf = [] 119 | for old_state in beams[i]: 120 | s1, s2, q1 = old_state.get_state() 121 | span = {"s1": s1, "s2": s2, "q1": q1} 122 | feat = { 123 | "org": self.get_organization_features(s1, s2, q1, doc, self.classifier.device) 124 | } 125 | act_scores, ful_scores = self.classifier.predict(bert_output, span, feat) 126 | log_act_scores = act_scores.log_softmax(dim=0) 127 | 128 | for act in old_state.allowed_actions(): 129 | rel, nuc = None, None 130 | if act != "shift": 131 | ful_scores[ful_vocab[""]] = -float("inf") 132 | label = ful_vocab.lookup_token(torch.argmax(ful_scores)) 133 | nuc, rel = label.split(":", maxsplit=1) 134 | 135 | 
act_idx = act_vocab[act] 136 | action_score = log_act_scores[act_idx].item() 137 | new_state = old_state.copy() 138 | new_state.operate(act, nuc, rel, score=action_score) 139 | buf.append(new_state) 140 | 141 | buf = sorted(buf, key=lambda x: x.score, reverse=True) # descending order 142 | 143 | tmp = {} 144 | beams[i + 1] = [] 145 | for j, new_state in enumerate(buf): 146 | _state = new_state.get_state() 147 | score = new_state.score 148 | idx = (_state, score) 149 | if idx not in tmp: 150 | tmp[idx] = new_state 151 | beams[i + 1].append(new_state) 152 | else: 153 | # print('duplicate') 154 | # tmp[_state].merge_state(new_state) 155 | pass 156 | 157 | if len(tmp) == top_k: 158 | break 159 | 160 | trees = [x.get_tree() for x in beams[-1][:top_k]] 161 | return trees 162 | -------------------------------------------------------------------------------- /src/models/parser/shift_reduce_parser_v3.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | from data.dataset import Dataset 7 | from data.doc import Doc 8 | from data.tree import RSTTree 9 | from models.classifier import ShiftReduceClassifierBase 10 | from models.parser import ShiftReduceParserBase 11 | from models.parser.shift_reduce_state import ShiftReduceState 12 | from models.parser.utils import batch_iter 13 | 14 | 15 | class ShiftReduceParserV3(ShiftReduceParserBase): 16 | """This shift-reduce parser predicts action and nuc label by one classifier.""" 17 | 18 | def __init__(self, classifier: ShiftReduceClassifierBase): 19 | super(ShiftReduceParserV3, self).__init__(classifier) 20 | assert isinstance(classifier, ShiftReduceClassifierBase) 21 | 22 | def generate_training_samples( 23 | self, 24 | dataset: Dataset, 25 | unit_type: str, 26 | batch_size: Optional[int] = None, 27 | ): 28 | act_nuc_vocab = dataset.act_nuc_vocab 29 | relation_vocab = dataset.relation_vocab 30 | 31 | samples = [] 32 | for doc in dataset: 33 | tree = doc.tree 34 | if isinstance(tree, RSTTree): 35 | tree = RSTTree.convert_to_attach(tree) 36 | 37 | act_list, nuc_list, rel_list = self.generate_action_sequence(tree) 38 | xs, ys, fs = [], [], [] 39 | state = ShiftReduceState(len(tree.leaves())) 40 | for act, nuc, rel in zip(act_list, nuc_list, rel_list): 41 | s1, s2, q1 = state.get_state() 42 | act_nuc = "_".join([act, nuc]) 43 | act_nuc_idx = act_nuc_vocab[act_nuc] 44 | rel_idx = relation_vocab[rel] 45 | org_feat = self.get_organization_features(s1, s2, q1, doc) 46 | xs.append({"s1": s1, "s2": s2, "q1": q1}) 47 | ys.append({"act_nuc": act_nuc_idx, "rel": rel_idx}) 48 | fs.append({"org": org_feat}) 49 | state.operate(act, nuc, rel) 50 | 51 | assert tree == state.get_tree() 52 | 53 | if unit_type == "document": 54 | samples.append({"doc": doc, "span": xs, "label": ys, "feat": fs}) 55 | elif unit_type == "span": 56 | for x, y, f in zip(xs, ys, fs): 57 | samples.append({"doc": doc, "span": x, "label": y, "feat": f}) 58 | elif unit_type == "span_fast": 59 | assert batch_size > 1 60 | # should use Trainer.reload_dataloaders_every_n_epochs=1 61 | indices = list(range(len(xs))) 62 | random.shuffle(indices) 63 | xs = [xs[i] for i in indices] 64 | ys = [ys[i] for i in indices] 65 | fs = [fs[i] for i in indices] 66 | for feats in batch_iter(list(zip(xs, ys, fs)), batch_size): 67 | b_xs, b_ys, b_fs = list(zip(*feats)) 68 | samples.append({"doc": doc, "span": b_xs, "label": b_ys, "feat": b_fs}) 69 | else: 70 | raise ValueError("Invalid batch unit_type 
({})".format(unit_type)) 71 | 72 | return samples 73 | 74 | def select_action_and_labels(self, bert_output, span, feat, state, gold_act=None): 75 | act_nuc_vocab = self.classifier.act_nuc_vocab 76 | rel_vocab = self.classifier.rel_vocab 77 | 78 | act_nuc_scores, rel_scores = self.classifier.predict(bert_output, span, feat) 79 | 80 | # select allowed action 81 | _, act_nuc_indices = torch.sort(act_nuc_scores, dim=0, descending=True) 82 | for act_nuc_idx in act_nuc_indices: 83 | act_nuc = act_nuc_vocab.lookup_token(act_nuc_idx) 84 | if act_nuc == "": 85 | continue 86 | 87 | act, nuc = act_nuc.split("_") 88 | if gold_act is not None: 89 | if act == gold_act: 90 | break 91 | 92 | else: 93 | if state.is_allowed_action(act): 94 | break 95 | 96 | rel = None 97 | if act != "shift": 98 | rel_scores[rel_vocab[""]] = -float("inf") 99 | rel = rel_vocab.lookup_token(torch.argmax(rel_scores)) 100 | 101 | return act, nuc, rel 102 | 103 | def parse_topk(self, doc: Doc, k: int): 104 | return self.BEAM(doc, k) 105 | 106 | def BEAM(self, doc: Doc, top_k: int): 107 | act_nuc_vocab = self.classifier.act_nuc_vocab 108 | rel_vocab = self.classifier.rel_vocab 109 | 110 | bert_output = self.classifier.encoder(doc) 111 | n_edus = len(doc.edus) 112 | 113 | num_steps = 2 * n_edus - 1 114 | beams = [None] * (num_steps + 1) 115 | beams[0] = [ShiftReduceState(n_edus)] # initial state 116 | 117 | for i in range(num_steps): 118 | buf = [] 119 | for old_state in beams[i]: 120 | s1, s2, q1 = old_state.get_state() 121 | span = {"s1": s1, "s2": s2, "q1": q1} 122 | feat = { 123 | "org": self.get_organization_features(s1, s2, q1, doc, self.classifier.device) 124 | } 125 | act_nuc_scores, rel_scores = self.classifier.predict(bert_output, span, feat) 126 | log_act_scores = self.classifier.act_nuc_to_act_scores(act_nuc_scores) 127 | 128 | for act in old_state.allowed_actions(): 129 | rel, nuc = None, None 130 | if act != "shift": 131 | act_nuc_scores[act_nuc_vocab["shift_"]] = -float("inf") 132 | act_nuc_scores[act_nuc_vocab[""]] = -float("inf") 133 | act_nuc = act_nuc_vocab.lookup_token(torch.argmax(act_nuc_scores)) 134 | _, nuc = act_nuc.split("_") 135 | rel_scores[rel_vocab[""]] = -float("inf") 136 | rel = rel_vocab.lookup_token(torch.argmax(rel_scores)) 137 | 138 | action_score = log_act_scores[act].item() 139 | new_state = old_state.copy() 140 | new_state.operate(act, nuc, rel, score=action_score) 141 | buf.append(new_state) 142 | 143 | buf = sorted(buf, key=lambda x: x.score, reverse=True) # descending order 144 | 145 | tmp = {} 146 | beams[i + 1] = [] 147 | for j, new_state in enumerate(buf): 148 | _state = new_state.get_state() 149 | score = new_state.score 150 | idx = (_state, score) 151 | if idx not in tmp: 152 | tmp[idx] = new_state 153 | beams[i + 1].append(new_state) 154 | else: 155 | # print('duplicate') 156 | # tmp[_state].merge_state(new_state) 157 | pass 158 | 159 | if len(tmp) == top_k: 160 | break 161 | 162 | trees = [x.get_tree() for x in beams[-1][:top_k]] 163 | return trees 164 | -------------------------------------------------------------------------------- /src/models/parser/shift_reduce_state.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | from data.edu import EDU 4 | from data.tree import AttachTree 5 | 6 | 7 | class ShiftReduceState(object): 8 | # TODO: 9 | # https://github.com/lianghuang3/lineardpparser/blob/master/code/newstate.py 10 | def __init__(self, n_edus: int): 11 | self.n_edus = n_edus 12 | self.stack = [] 
13 | self.queue = list(map(str, range(n_edus)[::-1])) 14 | self.score = 0 15 | 16 | def copy(self): 17 | # make new object 18 | x = ShiftReduceState(self.n_edus) 19 | # copy params 20 | x.stack = self.stack.copy() 21 | x.queue = self.queue.copy() 22 | x.score = self.score 23 | return x 24 | 25 | def operate(self, action: str, nuc: str, rel: str, score: float = 0): 26 | self.score = self.score + score 27 | 28 | if action == "shift": 29 | edu_idx = self.queue.pop() 30 | node = AttachTree("text", [edu_idx]) 31 | self.stack.append(node) 32 | elif action == "reduce": 33 | r_node = self.stack.pop() 34 | l_node = self.stack.pop() 35 | label = ":".join([nuc, rel]) 36 | new_node = AttachTree(label, [l_node, r_node]) 37 | self.stack.append(new_node) 38 | else: 39 | raise ValueError("unexpected action: {}".format(action)) 40 | 41 | def is_end(self): 42 | return len(self.stack) == 1 and len(self.queue) == 0 43 | 44 | def get_tree(self): 45 | if self.is_end(): 46 | return self.stack[0] 47 | else: 48 | raise ValueError 49 | 50 | def get_state(self): 51 | def get_edu_span(node: Union[AttachTree, str]): 52 | if isinstance(node, AttachTree): 53 | leaves = node.leaves() 54 | span = (int(leaves[0]), int(leaves[-1]) + 1) 55 | else: 56 | edu_idx = node 57 | span = (int(edu_idx), int(edu_idx) + 1) 58 | 59 | return span 60 | 61 | # stack top1, top2 62 | s1 = get_edu_span(self.stack[-1]) if len(self.stack) > 0 else (-1, -1) 63 | s2 = get_edu_span(self.stack[-2]) if len(self.stack) > 1 else (-1, -1) 64 | # queue first 65 | q1 = get_edu_span(self.queue[-1]) if len(self.queue) > 0 else (-1, -1) 66 | return s1, s2, q1 67 | 68 | def allowed_actions(self): 69 | actions = [] 70 | for act in ["shift", "reduce"]: 71 | if self.is_allowed_action(act): 72 | actions.append(act) 73 | 74 | return actions 75 | 76 | def is_allowed_action(self, action: str): 77 | if action == "shift": 78 | return len(self.queue) >= 1 79 | elif action == "reduce": 80 | return len(self.stack) >= 2 81 | else: 82 | raise ValueError 83 | 84 | def is_allowed_action_for_sentence_level_parse(self, action: str, edus: List[EDU]): 85 | s1, s2, q1 = self.get_state() 86 | if action == "shift": 87 | if len(self.queue) < 1: 88 | return False 89 | 90 | if len(self.stack) < 1: 91 | return True 92 | 93 | # Cases where shift is not allowed: 94 | # - if s1 and q1 belong to different sentences 95 | # and s1 is an incomplete sentence subtree, shift is not allowed. 96 | # * Both s1 and q1 must exist for this check. 97 | if not self.in_same_sentence(s1, q1, edus): 98 | if self.is_part_of_sentence_tree(s1, edus): 99 | return False 100 | 101 | return True 102 | 103 | elif action == "reduce": 104 | if len(self.stack) < 2: 105 | return False 106 | 107 | if len(self.queue) < 1: 108 | return True 109 | 110 | # Cases where reduce is not allowed: 111 | # - if s1 and q1 belong to the same sentence 112 | # but s1 and s2 belong to different sentences, reduce is not allowed. 113 | # * s1, s2, and q1 must all exist for this check. 
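# e.g., reduce is blocked when s1 and q1 share a sentence but s2 comes from an earlier sentence, since reducing s2 and s1 would cross the sentence boundary before s1's sentence subtree is complete.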
114 | if self.in_same_sentence(s1, q1, edus): 115 | if not self.in_same_sentence(s1, s2, edus): 116 | return False 117 | 118 | return True 119 | else: 120 | raise ValueError('Invalid action "{}"'.format(action)) 121 | 122 | def in_same_sentence(self, q1: Tuple[int], s1: Tuple[int], edus: List[EDU]): 123 | if q1 == (-1, -1) or s1 == (-1, -1): 124 | raise ValueError 125 | # q1 is edu idx 126 | # s1 is edu span 127 | q1_edu = edus[q1[0]] 128 | s1_edu = edus[s1[1] - 1] # most right edu of s1 129 | return q1_edu.sent_idx == s1_edu.sent_idx 130 | 131 | def is_part_of_sentence_tree(self, s: Tuple[int], edus: List[EDU]): 132 | if s == (-1, -1): 133 | raise ValueError 134 | # s is edu span 135 | s_left_edu = edus[s[0]] # most left edu of s 136 | s_right_edu = edus[s[1] - 1] # most right edu of s 137 | if s_left_edu.sent_idx != s_right_edu.sent_idx: 138 | # larger than sentence 139 | return False 140 | if s_left_edu.start_sent and s_right_edu.end_sent: 141 | # complete sentence 142 | return False 143 | 144 | # smaller than sentence 145 | return True 146 | -------------------------------------------------------------------------------- /src/models/parser/top_down_parser_base.py: -------------------------------------------------------------------------------- 1 | from heapq import heappop, heappush 2 | from typing import Tuple 3 | 4 | import torch 5 | 6 | from data.doc import Doc 7 | from data.tree import AttachTree 8 | from models.classifier import TopDownClassifierBase 9 | from models.parser import ParserBase 10 | from models.parser.organization_feature import OrganizationFeature as OrgFeat 11 | 12 | 13 | class TopDownParserBase(ParserBase): 14 | """base class for top-down parser.""" 15 | 16 | def __init__(self, classifier: TopDownClassifierBase): 17 | super(TopDownParserBase, self).__init__(classifier) 18 | assert isinstance(classifier, TopDownClassifierBase) 19 | 20 | def parse(self, doc: Doc): 21 | bert_output = self.classifier.encoder(doc) 22 | n_edus = len(doc.edus) 23 | span = (0, n_edus) 24 | 25 | def build_tree(span): 26 | if span[0] + 1 == span[1]: 27 | edu_idx = str(int(span[0])) 28 | return AttachTree("text", [edu_idx]) 29 | 30 | k = self.classifier.predict_split(bert_output, span, doc, return_scores=False) 31 | # k = torch.argmax(split_scores) + span[0] + 1 32 | # k = k.item() 33 | label = self.classifier.predict_label( 34 | bert_output, (span[0], k), (k, span[1]), doc, return_scores=False 35 | ) 36 | 37 | return AttachTree(label, [build_tree((span[0], k)), build_tree((k, span[1]))]) 38 | 39 | tree = build_tree(span) 40 | return tree 41 | 42 | def parse_topk(self, doc: Doc, k: int): 43 | naked_trees = self.CKY(doc, k) 44 | labeled_trees = [self.labeling_to_naked_tree(doc, tree) for tree in naked_trees] 45 | return labeled_trees 46 | 47 | def CKY(self, doc: Doc, top_k: int): 48 | bert_output = self.classifier.encoder(doc) 49 | n_edus = len(doc.edus) 50 | 51 | # prepare org_indices for all spans 52 | span_to_org_indices = {} 53 | for i in range(n_edus): 54 | for j in range(i, n_edus + 1): 55 | span = (i, j) 56 | org_indices = [] 57 | for k in range(i + 1, j): 58 | org_indices.append(self.get_organization_features((i, k), (k, j), doc)) 59 | 60 | span_to_org_indices[span] = org_indices 61 | 62 | # init tables 63 | CKY_table = [[[-1 for k in range(top_k)] for j in range(n_edus)] for i in range(n_edus)] 64 | count_table = [[-1 for j in range(n_edus)] for i in range(n_edus)] 65 | trace_table = [ 66 | [[(-1, -1, -1) for k in range(top_k)] for j in range(n_edus)] for i in range(n_edus) 67 | ] 68 
| 69 | def mult(log_split_scores, left_edu_idx, right_edu_idx): 70 | buf = [] 71 | for c_idx, k in enumerate(range(left_edu_idx, right_edu_idx)): 72 | left_scores = CKY_table[left_edu_idx][k][: count_table[left_edu_idx][k]] 73 | right_scores = CKY_table[k + 1][right_edu_idx][: count_table[k + 1][right_edu_idx]] 74 | split_score = log_split_scores[k - left_edu_idx].item() 75 | for l, left_score in enumerate(left_scores): 76 | for r, right_score in enumerate(right_scores): 77 | score = left_score + right_score + split_score 78 | buf.append( 79 | { 80 | "score": score, 81 | "trace": (l, k + 1, r), 82 | } 83 | ) 84 | 85 | sorted_buf = sorted(buf, key=lambda x: x["score"], reverse=True) 86 | top_k_buf = sorted_buf[:top_k] 87 | return top_k_buf 88 | 89 | def mult_with_heap(log_split_scores, left_edu_idx, right_edu_idx): 90 | i = left_edu_idx 91 | j = right_edu_idx 92 | heap = [] 93 | for k in range(i, j): 94 | # 0. start with an initial heap of the 1-best derivations 95 | left_scores = CKY_table[i][k] # corresponding to (i, k) 96 | right_scores = CKY_table[k + 1][j] # corresponding to (k, j) 97 | split_score = log_split_scores[k - i].item() 98 | l, r = 0, 0 99 | score = left_scores[l] + right_scores[r] + split_score 100 | item = (-score, (k, l, r)) 101 | heappush(heap, item) 102 | 103 | top_k_scores = [] 104 | heap_elm_dict = {} # remember heap items to avoid duplicates 105 | for _ in range(top_k): 106 | if heap == []: 107 | # no element in the heap 108 | break 109 | 110 | # 1. extract-max from the heap 111 | elm = heappop(heap) 112 | best_score = elm[0] 113 | indices = elm[1] # (k, l, r) 114 | k = indices[0] 115 | l = indices[1] 116 | r = indices[2] 117 | top_k_scores.append( 118 | { 119 | "score": -best_score, 120 | "trace": ( 121 | l, 122 | k + 1, 123 | r, 124 | ), # k+1 is the edu index of the split point 125 | } 126 | ) 127 | 128 | left_scores = CKY_table[i][k] 129 | right_scores = CKY_table[k + 1][j] 130 | split_score = log_split_scores[k - i].item() 131 | 132 | # 2. 
push the two "shoulders" into the heap 133 | if l + 1 < count_table[i][k]: 134 | score = left_scores[l + 1] + right_scores[r] + split_score 135 | item = (-score, (k, l + 1, r)) 136 | if item not in heap_elm_dict: 137 | heappush(heap, item) 138 | heap_elm_dict[item] = 0 139 | 140 | if r + 1 < count_table[k + 1][j]: 141 | score = left_scores[l] + right_scores[r + 1] + split_score 142 | item = (-score, (k, l, r + 1)) 143 | if item not in heap_elm_dict: 144 | heappush(heap, item) 145 | heap_elm_dict[item] = 0 146 | 147 | return top_k_scores 148 | 149 | # fill tables with DP 150 | for length in range(n_edus): 151 | for offset in range(n_edus): 152 | left_edu_idx, right_edu_idx = offset, offset + length 153 | if right_edu_idx > n_edus - 1: 154 | # out of range 155 | break 156 | 157 | if length == 0: 158 | # init leaf node with score log(1.0) == 0.0 159 | CKY_table[left_edu_idx][right_edu_idx][0] = torch.tensor(1.0).log().item() 160 | count_table[left_edu_idx][right_edu_idx] = 1 161 | continue 162 | 163 | span = (left_edu_idx, right_edu_idx + 1) 164 | split_scores = self.classifier.predict_split_fast( 165 | bert_output, span, doc, span_to_org_indices, return_scores=True 166 | ) 167 | log_split_scores = split_scores.log_softmax(dim=0) 168 | 169 | # top_k_buf = mult(log_split_scores, left_edu_idx, right_edu_idx) 170 | top_k_buf = mult_with_heap(log_split_scores, left_edu_idx, right_edu_idx) 171 | scores = [(x["score"]) for x in top_k_buf] 172 | CKY_table[left_edu_idx][right_edu_idx][: len(top_k_buf)] = scores 173 | traces = [(x["trace"]) for x in top_k_buf] 174 | trace_table[left_edu_idx][right_edu_idx][: len(top_k_buf)] = traces 175 | count_table[left_edu_idx][right_edu_idx] = len(top_k_buf) 176 | 177 | def backtrace(table, left, right, index): 178 | assert right - left >= 1 179 | label = "nucleus-satellite:Elaboration" # majority label 180 | if right - left == 1: 181 | edu_index = left 182 | return AttachTree("text", [str(edu_index)]) 183 | 184 | l, k, r = table[left][right - 1][index] 185 | return AttachTree(label, [backtrace(table, left, k, l), backtrace(table, k, right, r)]) 186 | 187 | # build tree by back-tracing 188 | trees = [ 189 | backtrace(trace_table, 0, n_edus, index) 190 | for index in range(top_k) 191 | if trace_table[0][n_edus - 1][index] != (-1, -1, -1) 192 | ] 193 | return trees 194 | 195 | def parse_with_naked_tree(self, doc: Doc, naked_tree: AttachTree): 196 | return self.labeling_to_naked_tree(doc, naked_tree) 197 | 198 | def labeling_to_naked_tree(self, doc, naked_tree): 199 | bert_output = self.classifier.encoder(doc) 200 | n_edus = len(doc.edus) 201 | span = (0, n_edus) 202 | 203 | def build_tree(span, given_tree): 204 | if span[0] + 1 == span[1]: 205 | edu_idx = str(int(span[0])) 206 | return AttachTree("text", [edu_idx]) 207 | 208 | # get a split point from the given tree 209 | _, leaves = given_tree.label(), given_tree.leaves() 210 | given_span = (int(leaves[0]), int(leaves[-1]) + 1) 211 | assert given_span == span 212 | given_split_idx = int(given_tree[1].leaves()[0]) - span[0] - 1 213 | k = given_split_idx + span[0] + 1 # == given_tree[1].leaves()[0] 214 | 215 | # predict labels 216 | label = self.classifier.predict_label( 217 | bert_output, (span[0], k), (k, span[1]), doc, return_scores=False 218 | ) 219 | 220 | return AttachTree( 221 | label, 222 | [ 223 | build_tree((span[0], k), given_tree[0]), 224 | build_tree((k, span[1]), given_tree[1]), 225 | ], 226 | ) 227 | 228 | labeled_tree = build_tree(span, naked_tree) 229 | return labeled_tree 230 | 231 | def 
get_organization_features( 232 | self, left_span: Tuple[int], right_span: Tuple[int], doc: Doc, device=None 233 | ): 234 | edus = doc.edus 235 | left_edus = edus[slice(*left_span)] 236 | right_edus = edus[slice(*right_span)] 237 | 238 | features = [] 239 | 240 | if not self.classifier.disable_org_sent: 241 | # same or continue sentence 242 | features.append(OrgFeat.IsSameSent(left_edus, right_edus)) 243 | features.append(OrgFeat.IsContinueSent(left_edus, right_edus)) 244 | 245 | # starts and ends a sentence 246 | features.append(OrgFeat.IsStartSent(left_edus)) 247 | features.append(OrgFeat.IsStartSent(right_edus)) 248 | features.append(OrgFeat.IsEndSent(left_edus)) 249 | features.append(OrgFeat.IsEndSent(right_edus)) 250 | features.append(OrgFeat.IsStartDoc(left_edus)) 251 | features.append(OrgFeat.IsStartDoc(right_edus)) 252 | features.append(OrgFeat.IsEndDoc(left_edus)) 253 | features.append(OrgFeat.IsEndDoc(right_edus)) 254 | 255 | if not self.classifier.disable_org_para: 256 | # same or continue sentence 257 | features.append(OrgFeat.IsSamePara(left_edus, right_edus)) 258 | features.append(OrgFeat.IsContinuePara(left_edus, right_edus)) 259 | 260 | # starts and ends a paragraph 261 | features.append(OrgFeat.IsStartPara(left_edus)) 262 | features.append(OrgFeat.IsStartPara(right_edus)) 263 | features.append(OrgFeat.IsEndPara(left_edus)) 264 | features.append(OrgFeat.IsEndPara(right_edus)) 265 | 266 | # convert to index 267 | bias = torch.tensor([2 * i for i in range(len(features))], dtype=torch.long, device=device) 268 | features = torch.tensor(features, dtype=torch.long, device=device) 269 | return bias + features 270 | -------------------------------------------------------------------------------- /src/models/parser/top_down_parser_v1.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | 4 | from data.dataset import Dataset 5 | from data.tree import AttachTree, RSTTree 6 | from models.classifier import TopDownClassifierBase 7 | from models.parser import TopDownParserBase 8 | from models.parser.utils import batch_iter 9 | 10 | 11 | class TopDownParserV1(TopDownParserBase): 12 | """This basic top-down parser individually predicts span, nucleus and relation.""" 13 | 14 | def __init__(self, classifier: TopDownClassifierBase): 15 | super(TopDownParserV1, self).__init__(classifier) 16 | assert isinstance(classifier, TopDownClassifierBase) 17 | 18 | def generate_training_samples( 19 | self, 20 | dataset: Dataset, 21 | unit_type: str, 22 | batch_size: Optional[int] = None, 23 | ): 24 | nuc_vocab = self.classifier.nuc_vocab 25 | rel_vocab = self.classifier.rel_vocab 26 | 27 | samples = [] 28 | for doc in dataset: 29 | tree = doc.tree 30 | if isinstance(tree, RSTTree): 31 | tree = RSTTree.convert_to_attach(tree) 32 | 33 | xs, ys, fs = [], [], [] 34 | for tp in tree.treepositions(): 35 | node = tree[tp] 36 | if not isinstance(node, AttachTree): 37 | continue 38 | label = node.label() 39 | if label == "text": 40 | continue 41 | 42 | leaves = node.leaves() 43 | span = (int(leaves[0]), int(leaves[-1]) + 1) 44 | split_idx = int(node[1].leaves()[0]) 45 | 46 | nuc, rel = label.split(":", maxsplit=1) 47 | nuc_idx = nuc_vocab[nuc] 48 | rel_idx = rel_vocab[rel] 49 | 50 | org_feat = self.get_organization_features( 51 | (span[0], split_idx), (split_idx, span[1]), doc 52 | ) 53 | 54 | xs.append({"i": span[0], "j": span[1], "k": split_idx}) 55 | ys.append({"spl": split_idx - span[0] - 1, "nuc": nuc_idx, "rel": rel_idx}) 56 | 
fs.append({"org": org_feat}) 57 | 58 | if unit_type == "document": 59 | samples.append({"doc": doc, "span": xs, "label": ys, "feat": fs}) 60 | elif unit_type == "span": 61 | for x, y, f in zip(xs, ys, fs): 62 | samples.append({"doc": doc, "span": x, "label": y, "feat": f}) 63 | elif unit_type == "span_fast": 64 | assert batch_size > 1 65 | # should use Trainer.reload_dataloaders_every_n_epochs=1 66 | indices = list(range(len(xs))) 67 | random.shuffle(indices) 68 | xs = [xs[i] for i in indices] 69 | ys = [ys[i] for i in indices] 70 | fs = [fs[i] for i in indices] 71 | for feats in batch_iter(list(zip(xs, ys, fs)), batch_size): 72 | b_xs, b_ys, b_fs = list(zip(*feats)) 73 | samples.append({"doc": doc, "span": b_xs, "label": b_ys, "feat": b_fs}) 74 | else: 75 | raise ValueError("Invalid batch unit_type ({})".format(unit_type)) 76 | 77 | return samples 78 | -------------------------------------------------------------------------------- /src/models/parser/top_down_parser_v2.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | 4 | from data.dataset import Dataset 5 | from data.tree import AttachTree, RSTTree 6 | from models.classifier import TopDownClassifierBase 7 | from models.parser import TopDownParserBase 8 | from models.parser.utils import batch_iter 9 | 10 | 11 | class TopDownParserV2(TopDownParserBase): 12 | """This top-down parser predicts nucleus and relation labels by one classifier.""" 13 | 14 | def __init__(self, classifier: TopDownClassifierBase): 15 | super(TopDownParserV2, self).__init__(classifier) 16 | assert isinstance(classifier, TopDownClassifierBase) 17 | 18 | def generate_training_samples( 19 | self, 20 | dataset: Dataset, 21 | unit_type: str, 22 | batch_size: Optional[int] = None, 23 | ): 24 | ful_vocab = self.classifier.ful_vocab 25 | 26 | samples = [] 27 | for doc in dataset: 28 | tree = doc.tree 29 | if isinstance(tree, RSTTree): 30 | tree = RSTTree.convert_to_attach(tree) 31 | 32 | xs, ys, fs = [], [], [] 33 | for tp in tree.treepositions(): 34 | node = tree[tp] 35 | if not isinstance(node, AttachTree): 36 | continue 37 | label = node.label() 38 | if label == "text": 39 | continue 40 | 41 | leaves = node.leaves() 42 | span = (int(leaves[0]), int(leaves[-1]) + 1) 43 | split_idx = int(node[1].leaves()[0]) 44 | 45 | ful_idx = ful_vocab[label] 46 | org_feat = self.get_organization_features( 47 | (span[0], split_idx), (split_idx, span[1]), doc 48 | ) 49 | 50 | xs.append({"i": span[0], "j": span[1], "k": split_idx}) 51 | ys.append({"spl": split_idx - span[0] - 1, "ful": ful_idx}) 52 | fs.append({"org": org_feat}) 53 | 54 | if unit_type == "document": 55 | samples.append({"doc": doc, "span": xs, "label": ys, "feat": fs}) 56 | elif unit_type == "span": 57 | for x, y, f in zip(xs, ys, fs): 58 | samples.append({"doc": doc, "span": x, "label": y, "feat": f}) 59 | elif unit_type == "span_fast": 60 | assert batch_size > 1 61 | # should use Trainer.reload_dataloaders_every_n_epochs=1 62 | indices = list(range(len(xs))) 63 | random.shuffle(indices) 64 | xs = [xs[i] for i in indices] 65 | ys = [ys[i] for i in indices] 66 | fs = [fs[i] for i in indices] 67 | for feats in batch_iter(list(zip(xs, ys, fs)), batch_size): 68 | b_xs, b_ys, b_fs = list(zip(*feats)) 69 | samples.append({"doc": doc, "span": b_xs, "label": b_ys, "feat": b_fs}) 70 | else: 71 | raise ValueError("Invalid batch unit_type ({})".format(unit_type)) 72 | 73 | return samples 74 | 
-------------------------------------------------------------------------------- /src/models/parser/utils.py: -------------------------------------------------------------------------------- 1 | def batch_iter(iterable, batch_size: int = 1): 2 | l = len(iterable) 3 | for offset in range(0, l, batch_size): 4 | yield iterable[offset : min(offset + batch_size, l)] 5 | -------------------------------------------------------------------------------- /src/parse.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | from pathlib import Path 5 | from typing import Dict, List, Optional, Union 6 | import tempfile 7 | import warnings 8 | 9 | import torch 10 | from transformers import logging 11 | 12 | from data.datamodule import DataModule 13 | from data.tree import AttachTree 14 | from models.classifier import Classifiers 15 | from models.parser import Parsers 16 | 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | warnings.filterwarnings("ignore") 19 | logging.set_verbosity_warning() 20 | logging.set_verbosity_error() 21 | 22 | 23 | def get_config(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | "--ckpt-path", 27 | type=Path, 28 | help="checkpoint (.ckpt) file", 29 | ) 30 | parser.add_argument( 31 | "--save-dir", type=Path, required=True, help="path to output directory" 32 | ) 33 | 34 | # dataset 35 | parser.add_argument( 36 | "--corpus", 37 | default="RSTDT", 38 | choices=["RSTDT", "InstrDT"], 39 | help="corpus type (label set is in src/data/dataset.py)", 40 | ) 41 | parser.add_argument( 42 | "--documents", 43 | nargs="+", 44 | default=None, 45 | required=True, 46 | help="path to raw document file or directory", 47 | ) 48 | args = parser.parse_args() 49 | args.train_file = None 50 | args.valid_file = None 51 | args.test_file = None 52 | args.data_dir = None 53 | args.num_workers = 0 54 | args.batch_size = 1 55 | args.batch_unit_type = "document" 56 | return args 57 | 58 | 59 | def raw_docs2json(documents: List[Path]): 60 | raise NotImplementedError 61 | docs = [] 62 | for doc_file in documents: 63 | doc = { 64 | "rst_tree": None, 65 | "tokens": None, 66 | "edu_start_indices": None, 67 | "tokenized_edu_strings": None, 68 | "edu_starts_sentence": None, 69 | } 70 | with open(doc_file) as f: 71 | edus = [] 72 | for line in f: 73 | line = line.strip() 74 | if line: # non-empty 75 | edu = line 76 | docs.append(doc) 77 | 78 | return docs 79 | 80 | 81 | def raw_docs2dataset(documents: Union[List, Path], config: argparse.Namespace): 82 | # list up the target documents 83 | doc_files = [] 84 | if isinstance(documents, Path): 85 | if documents.is_file(): # document file 86 | doc_file = documents 87 | doc_files = [doc_file] 88 | else: # directory containing document files 89 | doc_files = [doc_file for doc_file in documents.iterdir()] 90 | else: # list of document files 91 | doc_files = documents 92 | 93 | assert len(doc_files) > 0 94 | 95 | # convert to a datamodule via a json-formatted dataset 96 | dataset = None 97 | with tempfile.NamedTemporaryFile() as fp: 98 | json.dump(raw_docs2json(doc_files), fp) 99 | 100 | config.data_dir = os.path.dirname(fp.name) 101 | config.test_file = os.path.basename(fp.name) 102 | dataset = DataModule.from_config(config, parser=None) 103 | 104 | return dataset 105 | 106 | 107 | def main(): 108 | config = get_config() 109 | device = torch.device("cuda:0") # hard-coded 110 | ckpt_path = config.ckpt_path 111 | save_dir = config.save_dir 112 | dataset = 
113 |     parse(ckpt_path, dataset, save_dir, device)
114 |     print("trees for the given documents are saved into {}".format(save_dir))
115 |     return
116 | 
117 | 
118 | def parse(
119 |     ckpt_path: Union[Path, dict],
120 |     dataset: DataModule,
121 |     save_dir: Path,
122 |     device: torch.device,
123 | ):
124 |     # load params from checkpoint
125 |     if isinstance(ckpt_path, Path):
126 |         checkpoint = torch.load(ckpt_path)
127 | 
128 |     assert "state_dict" in checkpoint
129 |     hparams = checkpoint["hyper_parameters"]
130 | 
131 |     # build classifier with pre-trained weights
132 |     model_type = hparams["model_type"]
133 |     classifier_class = Classifiers.classifier_dict[model_type]
134 |     classifier = classifier_class.load_from_checkpoint(ckpt_path, map_location=device)
135 | 
136 |     # build parser
137 |     parser_class = Parsers.parser_dict[model_type]
138 |     parser = parser_class(classifier)
139 | 
140 |     classifier.eval()
141 |     classifier.to(device)
142 |     classifier.set_parser(parser)
143 | 
144 |     # build dataloader
145 |     dataset.set_parser(parser)
146 |     parse_set = dataset.test_dataloader()
147 | 
148 |     # parse
149 |     with torch.no_grad():
150 |         output = parser.parse_dataset(parse_set)
151 |         save_tree(output, save_dir)
152 | 
153 |     return
154 | 
155 | 
156 | def save_tree(output: Dict, save_dir: Optional[Path] = None):
157 |     if save_dir is None:
158 |         return
159 | 
160 |     save_dir.mkdir(parents=True, exist_ok=True)
161 |     for doc_id, tree in zip(output["doc_id"], output["pred_tree"]):
162 |         assert isinstance(tree, AttachTree)
163 |         with open(save_dir / "{}.tree".format(doc_id), "w") as f:
164 |             print(tree, file=f)
165 | 
166 |     return
167 | 
168 | 
169 | if __name__ == "__main__":
170 |     main()
171 | 
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | from pathlib import Path
5 | from typing import Dict, Optional, Union
6 | import warnings
7 | 
8 | import torch
9 | from transformers import logging
10 | 
11 | from average_ckpt import average_checkpoints
12 | from data.datamodule import DataModule
13 | from data.tree import AttachTree
14 | from metrics import RSTParseval, OriginalParseval
15 | from models.classifier import Classifiers
16 | from models.parser import Parsers
17 | 
18 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
19 | warnings.filterwarnings("ignore")
20 | logging.set_verbosity_warning()
21 | logging.set_verbosity_error()
22 | 
23 | 
24 | def get_config():
25 |     parser = argparse.ArgumentParser()
26 |     parser.add_argument(
27 |         "--ckpt-path",
28 |         type=Path,
29 |         help="checkpoint (.ckpt) file",
30 |     )
31 |     parser.add_argument(
32 |         "--ckpt-dir",
33 |         type=Path,
34 |         help="directory containing checkpoint (.ckpt) files",
35 |     )
36 |     parser.add_argument(
37 |         "--average-top-k",
38 |         type=int,
39 |         default=None,
40 |         help="number of checkpoints used to compute the averaged weights",
41 |     )
42 |     parser.add_argument("--save-dir", type=Path, required=True, help="path to output directory")
43 |     parser.add_argument(
44 |         "--batch-size",
45 |         type=int,
46 |         default=1,
47 |         help="mini-batch size (only batch-size=1 is supported for evaluation)",
48 |     )
49 |     parser.add_argument(
50 |         "--num-workers",
51 |         type=int,
52 |         default=0,
53 |         help="number of workers for dataloader (0 is enough)",
54 |     )
55 | 
56 |     # dataset
57 |     parser.add_argument(
58 |         "--corpus",
59 |         default="RSTDT",
60 |         choices=["RSTDT", "InstrDT"],
61 |         help="corpus type (label set is in src/data/dataset.py)",
62 |     )
63 |     parser.add_argument(
64 |         "--data-dir",
65 |         type=Path,
66 |         default="data/",
67 |         help="dataset directory containing the train/valid/test json files",
68 |     )
69 |     parser.add_argument(
70 |         "--train-file",
71 |         type=Path,
72 |         default="train.json",
73 |         help="file name of training dataset",
74 |     )
75 |     parser.add_argument(
76 |         "--valid-file",
77 |         type=Path,
78 |         default="valid.json",
79 |         help="file name of validation dataset",
80 |     )
81 |     parser.add_argument(
82 |         "--test-file", type=Path, default="test.json", help="file name of test dataset"
83 |     )
84 |     # metrics
85 |     parser.add_argument(
86 |         "--metrics",
87 |         type=str,
88 |         default="RSTParseval",
89 |         choices=["RSTParseval", "OriginalParseval"],
90 |         help="metrics type to report results",
91 |     )
92 |     args = parser.parse_args()
93 |     return args
94 | 
95 | 
96 | def main():
97 |     config = get_config()
98 |     device = torch.device("cuda:0")  # hard-coded
99 |     dataset = DataModule.from_config(config, parser=None)
100 | 
101 |     if config.ckpt_path:
102 |         test_single_checkpoint(config, device, dataset)
103 |     elif config.ckpt_dir:
104 |         test_multiple_checkpoints(config, device, dataset)
105 | 
106 |     return
107 | 
108 | 
109 | def test_single_checkpoint(config, device, dataset):
110 |     ckpt_path = config.ckpt_path
111 |     save_dir = config.save_dir
112 |     metrics_type = config.metrics
113 |     test(ckpt_path, dataset, save_dir, device, metrics_type)
114 |     print("trees of the given model are saved into {}".format(save_dir))
115 |     return
116 | 
117 | 
118 | def test_multiple_checkpoints(config, device, dataset):
119 |     # ckpt list
120 |     ckpt_path_list = []
121 |     for ckpt_path in config.ckpt_dir.iterdir():
122 |         if ckpt_path.suffix != ".ckpt":
123 |             continue
124 |         if "average" in str(ckpt_path):
125 |             continue
126 |         if "last" in str(ckpt_path):
127 |             continue
128 |         if "best" in str(ckpt_path):
129 |             continue
130 |         ckpt_path_list.append(ckpt_path)
131 | 
132 |     metrics_type = config.metrics
133 | 
134 |     # evaluate each checkpoint
135 |     scores = []
136 |     for ckpt_path in ckpt_path_list:
137 |         print(ckpt_path)
138 |         # ckpt_path looks like <dir>/epoch=n-step=m.ckpt
139 |         model_name = ckpt_path.stem
140 |         save_dir = config.save_dir / model_name
141 |         valid_score = test(ckpt_path, dataset, save_dir, device, metrics_type)
142 |         scores.append({"path": ckpt_path, "score": valid_score})
143 | 
144 |     sorted_scores = sorted(scores, reverse=True, key=lambda x: x["score"]["RSTParseval-F"])
145 | 
146 |     # select the best checkpoint with the validation score
147 |     best_ckpt_path = sorted_scores[0]["path"]
148 |     print("the best model: {}".format(best_ckpt_path))
149 |     shutil.copyfile(best_ckpt_path, config.ckpt_dir / "best.ckpt")
150 |     print("the best model was saved as {}".format(config.ckpt_dir / "best.ckpt"))
151 |     save_dir = config.save_dir / "best"
152 |     print("evaluate the best model")
153 |     test(best_ckpt_path, dataset, save_dir, device, metrics_type)
154 |     print("trees of the best model are saved into {}".format(save_dir))
155 | 
156 |     # select top_k models with the validation score
157 |     ckpt_path_list_for_avg = [m["path"] for m in sorted_scores[: config.average_top_k]]
158 |     print("models for weight average:")
159 |     for path in ckpt_path_list_for_avg:
160 |         print(" - {}".format(path))
161 | 
162 |     avg_ckpt_path = config.ckpt_dir / "average.ckpt"
163 |     avg_ckpt = average_checkpoints(ckpt_path_list_for_avg)
164 |     torch.save(avg_ckpt, avg_ckpt_path)
165 |     print("the averaged model was saved as {}".format(avg_ckpt_path))
166 | 
167 |     print("evaluate the averaged model")
168 |     save_dir = config.save_dir / "average"
169 |     test(avg_ckpt_path, dataset, save_dir, device, metrics_type)
170 |     print("trees of the averaged model are saved into {}".format(save_dir))
171 | 
172 |     return
173 | 
174 | 
175 | def test(
176 |     ckpt_path: Union[Path, dict],
177 |     dataset: DataModule,
178 |     save_dir: Path,
179 |     device: torch.device,
180 |     metrics_type: str,
181 | ):
182 |     # load params from checkpoint
183 |     if isinstance(ckpt_path, Path):
184 |         checkpoint = torch.load(ckpt_path)
185 | 
186 |     assert "state_dict" in checkpoint
187 |     hparams = checkpoint["hyper_parameters"]
188 | 
189 |     # build classifier with pre-trained weights
190 |     model_type = hparams["model_type"]
191 |     classifier_class = Classifiers.classifier_dict[model_type]
192 |     classifier = classifier_class.load_from_checkpoint(ckpt_path, map_location=device)
193 | 
194 |     # build parser
195 |     parser_class = Parsers.parser_dict[model_type]
196 |     parser = parser_class(classifier)
197 | 
198 |     classifier.eval()
199 |     classifier.to(device)
200 |     classifier.set_parser(parser)
201 | 
202 |     # build dataloader
203 |     dataset.set_parser(parser)
204 |     test_set = dataset.test_dataloader()
205 |     valid_set = dataset.val_dataloader()[0]
206 | 
207 |     metric = {
208 |         "RSTParseval": RSTParseval(),
209 |         "OriginalParseval": OriginalParseval(),
210 |     }[metrics_type]
211 |     with torch.no_grad():
212 |         output = parser.parse_dataset(valid_set)
213 |         metric.update(output["pred_tree"], output["gold_tree"])
214 |         valid_score = metric.compute()
215 |         save_tree(output, save_dir / "valid")
216 |         print(valid_score)
217 |         metric.reset()
218 | 
219 |         output = parser.parse_dataset(test_set)
220 |         metric.update(output["pred_tree"], output["gold_tree"])
221 |         test_score = metric.compute()
222 |         save_tree(output, save_dir / "test")
223 |         print(test_score)
224 |         metric.reset()
225 | 
226 |     return valid_score
227 | 
228 | 
229 | def save_tree(output: Dict, save_dir: Optional[Path] = None):
230 |     if save_dir is None:
231 |         return
232 | 
233 |     save_dir.mkdir(parents=True, exist_ok=True)
234 |     for doc_id, tree in zip(output["doc_id"], output["pred_tree"]):
235 |         assert isinstance(tree, AttachTree)
236 |         with open(save_dir / "{}.tree".format(doc_id), "w") as f:
237 |             print(tree, file=f)
238 | 
239 |     return
240 | 
241 | 
242 | if __name__ == "__main__":
243 |     main()
244 | 
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 | 
5 | import pytorch_lightning as pl
6 | import torch.multiprocessing
7 | from pytorch_lightning import seed_everything
8 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
9 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
10 | from pytorch_lightning.loggers import TensorBoardLogger
11 | 
12 | from data.datamodule import DataModule
13 | from models.classifier import Classifiers
14 | from models.parser import Parsers
15 | 
16 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
17 | torch.multiprocessing.set_sharing_strategy("file_system")
18 | 
19 | 
20 | def main():
21 |     parser = argparse.ArgumentParser()
22 |     # dataset
23 |     parser.add_argument(
24 |         "--corpus",
25 |         default="RSTDT",
26 |         choices=["RSTDT", "InstrDT"],
27 |         help="corpus type (label set is in src/data/dataset.py)",
28 |     )
29 |     parser.add_argument(
30 |         "--data-dir",
31 |         type=Path,
32 |         default="data/",
33 |         help="dataset directory containing the train/valid/test json files",
34 |     )
35 |     parser.add_argument(
36 |         "--train-file",
37 |         type=Path,
38 |         default="train.json",
39 |         help="file name of training dataset",
40 |     )
41 |     parser.add_argument(
42 |         "--valid-file",
43 |         type=Path,
44 |         default="valid.json",
45 |         help="file name of validation dataset",
46 |     )
47 |     parser.add_argument(
48 |         "--test-file", type=Path, default="test.json", help="file name of test dataset"
49 |     )
50 | 
51 |     # model parameters
52 |     parser.add_argument(
53 |         "--model-type",
54 |         required=True,
55 |         choices=[
56 |             "top_down_v1",
57 |             "top_down_v2",
58 |             "shift_reduce_v1",
59 |             "shift_reduce_v2",
60 |             "shift_reduce_v3",
61 |         ],
62 |         help="model type",
63 |     )
64 |     parser.add_argument(
65 |         "--hidden-dim",
66 |         type=int,
67 |         default=512,
68 |         help="hidden dimension size of classifiers",
69 |     )
70 |     parser.add_argument(
71 |         "--dropout-p", type=float, default=0.2, help="dropout ratio of classifiers"
72 |     )
73 |     parser.add_argument(
74 |         "--disable-penalty",
75 |         action="store_true",
76 |         help="disable the loss penalty method (Koto et al. 2021) in the top-down parser",
77 |     )  # for top_down_v1, 2, shift_reduce_v3
78 |     parser.add_argument(
79 |         "--disable-org-sent",
80 |         action="store_true",
81 |         help="disable organization features extracted from sentence boundaries",
82 |     )
83 |     parser.add_argument(
84 |         "--disable-org-para",
85 |         action="store_true",
86 |         help="disable organization features extracted from paragraph boundaries",
87 |     )
88 |     # bert encoder
89 |     parser.add_argument(
90 |         "--bert-model-name",
91 |         required=True,
92 |         choices=[
93 |             "bert-base-cased",
94 |             "bert-large-cased",
95 |             "roberta-base",
96 |             "roberta-large",
97 |             "xlnet-base-cased",
98 |             "xlnet-large-cased",
99 |             "spanbert-base-cased",
100 |             "spanbert-large-cased",
101 |             "electra-base-discriminator",
102 |             "electra-large-discriminator",
103 |             "mpnet-base",
104 |             "deberta-base",
105 |             "deberta-large",
106 |         ],
107 |         help="encoder type",
108 |     )
109 |     parser.add_argument(
110 |         "--bert-max-length",
111 |         type=int,
112 |         default=512,
113 |         help="maximum length for the sliding window",
114 |     )
115 |     parser.add_argument(
116 |         "--bert-stride",
117 |         type=int,
118 |         default=30,
119 |         help="stride size for long documents "
120 |         "(when stride is N, adjacent sliding windows share N sub-tokens)",
121 |     )
122 | 
123 |     # optimizer settings
124 |     parser.add_argument(
125 |         "--lr-for-encoder",
126 |         type=float,
127 |         default=1e-5,
128 |         help="learning rate for the weights of the pre-trained encoder",
129 |     )
130 |     parser.add_argument(
131 |         "--lr",
132 |         type=float,
133 |         default=2e-4,
134 |         help="learning rate for the randomly initialized classifiers",
135 |     )
136 |     parser.add_argument(
137 |         "--disable-lr-schedule",
138 |         action="store_true",
139 |         help="disable the learning rate scheduler",
140 |     )
141 |     parser.add_argument(
142 |         "--gradient-clip-val",
143 |         type=float,
144 |         default=1.0,
145 |         help="value of gradient clipping",
146 |     )
147 | 
148 |     # training settings
149 |     parser.add_argument("--epochs", type=int, default=20, help="number of epochs for training")
150 |     parser.add_argument(
151 |         "--batch-unit-type",
152 |         choices=["span", "span_fast", "document"],
153 |         default="span_fast",
154 |         help="unit type of mini-batch",
155 |         # span: 1 batch is 1 span
156 |         # document: 1 batch is 1 document (all spans of the document)
157 |         # span_fast: 1 batch is N spans from 1 document (N is batch_size)
158 |     )
159 |     parser.add_argument(
160 |         "--batch-size",
161 |         type=int,
162 |         default=5,
163 |         help="mini-batch size (only used when batch-unit-type == span_fast)",
164 |     )
165 |     parser.add_argument(
166 |         "--accumulate-grad-batches",
167 |         type=int,
168 |         default=1,
169 |         help="when batch-unit-type is span/document, "
170 |         "increase the effective batch size via gradient accumulation",
171 |     )
172 |     parser.add_argument("--train-from", type=Path, default=None, help="path to pre-trained model")
173 |     parser.add_argument(
174 |         "--disable-span-level-validation",
175 |         action="store_true",
176 |         help="disable span-level validation for faster training",
177 |     )
178 |     parser.add_argument(
179 |         "--disable-early-stopping", action="store_true", help="disable early stopping"
180 |     )
181 |     parser.add_argument("--patience", type=int, default=5, help="patience of early stopping")
182 | 
183 |     # save directory and model name
184 |     parser.add_argument("--save-dir", type=Path, required=True, help="path to output directory")
185 |     parser.add_argument(
186 |         "--no-save-checkpoint",
187 |         action="store_true",
188 |         help="disable saving model checkpoints",
189 |     )
190 |     parser.add_argument(
191 |         "--model-version", type=int, default=None, help="integer value for versioning"
192 |     )
193 |     parser.add_argument("--model-name", default=None, help="model name")
194 |     parser.add_argument("--num-gpus", type=int, default=0, help="number of GPUs")
195 |     parser.add_argument(
196 |         "--num-workers",
197 |         type=int,
198 |         default=0,
199 |         help="number of workers for dataloader (0 is enough)",
200 |     )
201 | 
202 |     # random seed
203 |     parser.add_argument("--seed", type=int, default=None, help="integer value for random seed")
204 |     config = parser.parse_args()
205 | 
206 |     if config.batch_unit_type == "span_fast":
207 |         assert config.batch_size > 0
208 |         assert config.accumulate_grad_batches == 1
209 | 
210 |     if config.seed is not None:
211 |         seed_everything(config.seed, workers=True)
212 | 
213 |     classifier = Classifiers.from_config(config)
214 |     parser = Parsers.from_config(config, classifier)
215 |     classifier.set_parser(parser)  # for validation step.
216 | 
217 |     data_module = DataModule.from_config(config, parser)
218 |     training_steps = len(data_module.train_dataloader())
219 |     classifier.set_training_steps_par_epoch(training_steps)
220 | 
221 |     logger = TensorBoardLogger(
222 |         save_dir=config.save_dir,
223 |         name=config.model_name,
224 |         version=config.model_version,
225 |         default_hp_metric=False,
226 |     )
227 | 
228 |     callbacks = [LearningRateMonitor(logging_interval="step")]
229 | 
230 |     if not config.disable_early_stopping:
231 |         monitor_metric = "valid/OriginalParseval-F"
232 |         callbacks.append(
233 |             EarlyStopping(monitor=monitor_metric, mode="max", patience=config.patience)
234 |         )
235 | 
236 |     if config.no_save_checkpoint:
237 |         enable_checkpointing = False
238 |     else:
239 |         enable_checkpointing = True
240 |         monitor_metric = "valid/OriginalParseval-F"
241 |         callbacks.append(
242 |             ModelCheckpoint(monitor=monitor_metric, mode="max", save_last=True, save_top_k=3)
243 |         )
244 | 
245 |     trainer = pl.Trainer(
246 |         max_epochs=config.epochs,
247 |         gpus=config.num_gpus,
248 |         gradient_clip_val=config.gradient_clip_val,
249 |         accumulate_grad_batches=config.accumulate_grad_batches,
250 |         default_root_dir=config.save_dir,
251 |         resume_from_checkpoint=config.train_from,
252 |         # accelerator='ddp',
253 |         # plugins=DDPPlugin(find_unused_parameters=True),
254 |         callbacks=callbacks,
255 |         enable_checkpointing=enable_checkpointing,
256 |         logger=logger,
257 |         val_check_interval=0.33,  # run validation 3 times per epoch.
258 |         reload_dataloaders_every_n_epochs=1 if config.batch_unit_type == "span_fast" else 0,
259 |         num_sanity_val_steps=0,
260 |         # detect_anomaly=True,
261 |     )
262 | 
263 |     trainer.fit(classifier, data_module)
264 | 
265 |     return
266 | 
267 | 
268 | if __name__ == "__main__":
269 |     main()
270 | 
--------------------------------------------------------------------------------
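
As a side note on the checkpoint averaging used in src/test.py above: average_checkpoints lives in src/average_ckpt.py, which is not shown in this excerpt. Below is a minimal sketch of the underlying idea, assuming it boils down to an element-wise mean of the checkpoints' state_dict tensors; the helper name and details are illustrative, not the repository's actual implementation.

    from typing import List

    import torch


    def average_state_dicts(ckpt_paths: List[str]) -> dict:
        """Element-wise mean of the state_dict tensors of several Lightning checkpoints."""
        avg = None
        for path in ckpt_paths:
            state = torch.load(path, map_location="cpu")["state_dict"]
            if avg is None:
                # copy the first checkpoint as the running sum
                avg = {key: value.clone().float() for key, value in state.items()}
            else:
                for key, value in state.items():
                    avg[key] += value.float()
        # divide the running sum by the number of checkpoints
        return {key: value / len(ckpt_paths) for key, value in avg.items()}

Note that test() in src/test.py also reads checkpoint["hyper_parameters"] from the averaged file, so whatever average_checkpoints returns has to preserve that entry (and the "state_dict" key) alongside the averaged weights.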