├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── requirements.txt ├── setup.sh └── src ├── PATH.sh ├── data_scripts ├── PRETRAIN.sh ├── ace2005_joint_er.sh ├── ace2005event.sh ├── ade.sh ├── atis.sh ├── conll04.sh ├── conll05_srl.sh ├── conll12_srl.sh ├── jsonl2json.py ├── multi_woz.sh ├── multi_woz_create_data.py ├── nyt.sh ├── skeleton2conll.py ├── skeleton2conll.sh └── snips.sh ├── dataset_processing ├── ace2005event_types.json ├── arguments.py ├── base_dataset.py ├── config.ini ├── coreference_metrics.py ├── datasets.py ├── evaluate.py ├── input_example.py ├── input_formats.py ├── output_formats.py ├── preprocess_multiwoz │ ├── extract_examples.py │ └── prepare_multi_woz.py ├── run.py └── utils.py ├── download_ckpt.sh ├── glm ├── arguments.py ├── config_tasks │ ├── config.json │ ├── config_mutliserver.json │ ├── model_blocklm_10B_pretrain.sh │ └── pretrain.sh ├── data_utils │ ├── tokenization.py │ └── tokenization_gpt2.py ├── evaluate.py ├── finetune_glm.py ├── generation_utils.py ├── model │ ├── __init__.py │ └── modeling_glm.py ├── pretrain_glm.py ├── scripts │ ├── ds_finetune_seq2seq.sh │ ├── ds_finetune_seq2seq_multiserver.sh │ └── ds_finetune_seq2seq_pretrain.sh ├── tasks │ ├── data_utils.py │ ├── eval_utils.py │ ├── seq2seq │ │ ├── dataset.py │ │ ├── evaluate.py │ │ └── finetune.py │ └── superglue │ │ ├── dataset.py │ │ ├── evaluate.py │ │ └── finetune.py ├── train_utils.py └── zero_shot.py ├── manager.py ├── run_scripts ├── ace2005_jer.sh ├── ace2005event.sh ├── ade.sh ├── atis.sh ├── conll04.sh ├── conll05_srl_brown.sh ├── conll05_srl_wsj.sh ├── conll12_srl.sh ├── multi_woz.sh ├── nyt.sh └── snips.sh └── tasks └── mt ├── ace2005_argument.sh ├── ace2005_ent.sh ├── ace2005_rel.sh ├── ace2005_trigger.sh ├── ade0_ent.sh ├── ade0_rel.sh ├── atis.sh ├── conll04_ent.sh ├── conll04_rel.sh ├── conll05_brown.sh ├── conll05_wsj.sh ├── conll12.sh ├── multi_woz.sh ├── nyt_ent.sh ├── nyt_rel.sh └── snips.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | glm/runs/ 4 | glm/results.json 5 | glm/logs/ 6 | data_scripts/ 7 | logs/ 8 | scripts/ 9 | tasks/ 10 | running.log 11 | manager.py 12 | PATH.sh 13 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "glm"] 2 | path = glm 3 | url = https://github.com/THUDM/GLM.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 XXX 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepStruct: Pretraining of Language Models for Structure Prediction 2 | 3 | Source code repository for the paper [DeepStruct: Pretraining of Language Models for Structure Prediction](https://arxiv.org/abs/2205.10475), ACL 2022. 4 | 5 | 6 | ## Setup Environment 7 | 8 | DeepStruct builds on [GLM](https://github.com/THUDM/GLM). Please use GLM's Docker images as follows to set up the basic GPU environment (`zxdu20/glm-cuda112` for Ampere GPUs and `zxdu20/glm-cuda102` for older GPUs such as Tesla V100). 9 | 10 | ```bash 11 | git clone --recursive git@github.com:cgraywang/deepstruct.git 12 | cd ./deepstruct 13 | 14 | docker run --net=host --privileged --pid=host --gpus all --rm -it --ipc=host -v $(pwd):/workspace/deepstruct zxdu20/glm-cuda112 15 | cd /workspace/deepstruct 16 | ``` 17 | 18 | and install the dependencies via `setup.sh`: 19 | 20 | ```bash 21 | bash setup.sh 22 | ``` 23 | 24 | The final directory structure should be as follows: 25 | 26 | ``` 27 | workspace/ 28 | ├─ deepstruct/ 29 | ├─ data/ 30 | ├─ ckpt/ 31 | ``` 32 | 33 | ## Download Checkpoints 34 | 35 | Most of our experiments are based on the 10-billion-parameter DeepStruct checkpoint. Run the following shell script to download all multi-task trained DeepStruct checkpoints from the Hugging Face Hub (this might take a while). 36 | 37 | ```bash 38 | bash download_ckpt.sh 39 | ``` 40 | 41 | ## Data Preparation & Reproduction 42 | 43 | The following experiments on DeepStruct-10B adopt `batch_size_per_gpu=1` and require at least 32 GB of GPU memory. 44 | The scripts default to `--num-gpus-per-node=1` in `src/tasks/mt/*.sh`; if you want to use multiple GPUs for acceleration, please customize that value in `src/tasks/mt/*.sh`. 45 | 46 | Note that `CoNLL12` and `CoNLL05` for semantic role labeling and `ACE2005` for event extraction require manual downloads from the LDC ([LDC2006T06](https://catalog.ldc.upenn.edu/LDC2006T06), [LDC2013T19](https://catalog.ldc.upenn.edu/LDC2013T19), [PTB-3](https://catalog.ldc.upenn.edu/LDC99T42)). 47 | 48 | | Task | Dataset | Data preparation | Multi-task Result | 49 | |--------------------------------------|---------------|------------------------------------------|----------------------------------------------------| 50 | | Joint entity and relation extraction | CoNLL04 | `bash run_scripts/conll04.sh` | Ent. 88.4/Rel. 72.8 | 51 | | Joint entity and relation extraction | ADE | `bash run_scripts/ade.sh` | Ent. 90.5/Rel. 83.6 | 52 | | Joint entity and relation extraction | NYT | `bash run_scripts/nyt.sh` | Ent. 95.4/Rel. 93.7 | 53 | | Joint entity and relation extraction | ACE2005 | `bash run_scripts/ace2005_jer.sh` | Ent. 90.2/Rel. 58.9 |
54 | | Semantic role labeling | CoNLL05 WSJ | `bash run_scripts/conll05_srl_wsj.sh` | 95.5 | 55 | | Semantic role labeling | CoNLL05 Brown | `bash run_scripts/conll05_srl_brown.sh` | 92.0 | 56 | | Semantic role labeling | CoNLL12 | `bash run_scripts/conll12_srl.sh` | 97.2 | 57 | | Event extraction | ACE2005 | `bash run_scripts/ace2005event.sh` | Trigger: Id-72.7/Cl-69.2 Argument: Id-67.5/Cl-63.9 | 58 | | Intent detection | ATIS | `bash run_scripts/atis.sh` | 97.3 | 59 | | Intent detection | SNIPS | `bash run_scripts/snips.sh` | 97.4 | 60 | | Dialogue state tracking | MultiWOZ 2.1 | `bash run_scripts/multi_woz.sh` | 53.5 | 61 | 62 | ## Arguments in running scripts 63 | The arguments in `src/tasks/mt/*.sh` configure the training and inference of DeepStruct. Here are their meanings: 64 | 65 | * `--model-type`: the type of model backbone to use. Currently we only support `model_blocklm_10B`, i.e., using the 10-billion-parameter DeepStruct model as the backbone. 66 | * `--model-checkpoint`: the path to the directory of the DeepStruct checkpoint. 67 | * `--task`: the task to train or evaluate. 68 | * `--task-epochs`: the number of epochs to run. If set to `0`, only evaluation is performed. 69 | * `--length-penalty`: a hyperparameter that controls the lengths of generated sequences during beam search. 70 | 71 | 72 | ## Scripts for Pretraining 73 | 74 | Follow the commands below to prepare the pretraining data and run training. 75 | 76 | ```bash 77 | # prepare pretraining data 78 | bash data_scripts/PRETRAIN.sh 79 | 80 | # run pretraining 81 | cd ./glm/ 82 | bash scripts/ds_finetune_seq2seq_pretrain.sh config_tasks/<MODEL_TYPE>.sh config_tasks/pretrain.sh cnn_dm_original 83 | ``` 84 | 85 | Currently `<MODEL_TYPE>` supports `model_blocklm_10B_pretrain`, which refers to using the 10-billion-parameter pretrained model as the backbone. 86 | 87 | Please customize `NUM_GPUS_PER_WORKER` in `glm/scripts/ds_finetune_seq2seq_pretrain.sh` and `train_micro_batch_size_per_gpu` in `glm/config_tasks/config.json` according to your environment, as fine-tuning a 10B language model requires substantial GPU memory. 88 | The data preprocessing for pretraining may require over 600 GB of main memory, as the current dataloader implementation preloads all tokenized data into main memory during pretraining. 89 | 90 | ## Citation 91 | 92 | ```bibtex 93 | @inproceedings{wang-etal-2022-deepstruct, 94 | title = "{D}eep{S}truct: Pretraining of Language Models for Structure Prediction", 95 | author = "Wang, Chenguang and 96 | Liu, Xiao and 97 | Chen, Zui and 98 | Hong, Haoyun and 99 | Tang, Jie and 100 | Song, Dawn", 101 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2022", 102 | year = "2022", 103 | } 104 | ``` 105 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | docopt 2 | fuzzywuzzy 3 | gdown 4 | joblib 5 | jsonlines 6 | pyheaven 7 | python-Levenshtein 8 | scikit-learn 9 | transformers==4.11.3 10 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | cp -rf src/* ./ 2 | pip install -r requirements.txt --user 3 | mkdir -p ../data 4 | mkdir -p ../ckpt 5 | -------------------------------------------------------------------------------- /src/PATH.sh: -------------------------------------------------------------------------------- 1 | ROOT=../..
2 | CHECKPOINTROOT=./ -------------------------------------------------------------------------------- /src/data_scripts/PRETRAIN.sh: -------------------------------------------------------------------------------- 1 | gdown --fuzzy https://drive.google.com/file/d/1xDb9F8cAEx36gPwxmRmqwzWkezUEPJ-i/view -O ../data/v0.5.4.tar.gz 2 | tar -xzvf ../data/v0.5.4.tar.gz -C ../data 3 | mv ../data/v0.5.4 ../data/cnn_dm_original 4 | rm -f ../data/v0.5.4.tar.gz 5 | 6 | wget https://lfs.aminer.cn/misc/cogview/glm-10b-1024.zip 7 | unzip glm-10b-1024.zip -d ../ckpt 8 | mv ../ckpt/glm-10b-1024 ../ckpt/PRETRAIN 9 | -------------------------------------------------------------------------------- /src/data_scripts/ace2005_joint_er.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ../data/ace2005_joint_er/ 2 | ACE05_PATH=$1 3 | if [ ! -n "$ACE05_PATH" ]; then 4 | echo -e "\033[31mPlease untar ace2005 data from https://catalog.ldc.upenn.edu/LDC2006T06 at path ACE05_PATH (absolute path)\033[0m" 5 | exit 1 6 | fi 7 | echo "ACE05_PATH = ${ACE05_PATH}" 8 | 9 | git clone https://github.com/dwadden/dygiepp.git 10 | cd ./dygiepp 11 | conda create --name ace-jer-preprocess python=3.7 -y 12 | source activate ace-jer-preprocess 13 | pip install -r requirements.txt 14 | conda develop . 15 | sudo apt-get update 16 | sudo apt-get install zsh -y 17 | sudo apt-get install openjdk-8-jdk -y 18 | # Please prepare ace2005 data from https://catalog.ldc.upenn.edu/LDC2006T06 at path ACE05_PATH 19 | bash scripts/data/ace05/get_corenlp.sh 20 | bash scripts/data/get_ace05.sh ${ACE05_PATH} 21 | cp -rf ./data/ace05/processed-data/json/* ../../data/ace2005_joint_er/ 22 | cd ../ 23 | rm -rf dygiepp 24 | cp -r ../data/ace2005_joint_er ../data/ace2005_joint_er_re 25 | source deactivate 26 | -------------------------------------------------------------------------------- /src/data_scripts/ace2005event.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ../data/ace2005event_trigger/ 2 | mkdir -p ../data/ace2005event_argument/ 3 | ACE05_PATH=$1 4 | if [ ! 
-n "$ACE05_PATH" ]; then 5 | echo -e "\033[31mPlease untar ace2005 data from https://catalog.ldc.upenn.edu/LDC2006T06 at path ACE05_PATH (absolute path)\033[0m" 6 | exit 1 7 | fi 8 | echo "ACE05_PATH = ${ACE05_PATH}" 9 | 10 | git clone https://github.com/dwadden/dygiepp.git 11 | cd ./dygiepp 12 | conda create --name ace-event-preprocess -y python=3.7 13 | source activate ace-event-preprocess 14 | python -m pip install -r requirements.txt 15 | python -m pip install -r scripts/data/ace-event/requirements.txt 16 | python -m spacy download en_core_web_sm 17 | # Please prepare ace2005 data from https://catalog.ldc.upenn.edu/LDC2006T06 at path ACE05_PATH 18 | bash ./scripts/data/ace-event/collect_ace_event.sh ${ACE05_PATH} 19 | python ./scripts/data/ace-event/parse_ace_event.py default-settings 20 | mkdir -p data/ace-event/collated-data/default-settings/json 21 | python -m scripts.data.shared.collate \ 22 | data/ace-event/processed-data/default-settings/json \ 23 | data/ace-event/collated-data/default-settings/json \ 24 | --file_extension json 25 | cp -rf ./data/ace-event/processed-data/default-settings/json/* ../../data/ace2005event_trigger/ 26 | cp -rf ./data/ace-event/processed-data/default-settings/json/* ../../data/ace2005event_argument/ 27 | cd ../ 28 | rm -rf dygiepp 29 | source deactivate 30 | python data_scripts/jsonl2json.py -i ../data/ace2005event_trigger/train.json -o ../data/ace2005event_trigger/ace2005event_train.json 31 | python data_scripts/jsonl2json.py -i ../data/ace2005event_trigger/dev.json -o ../data/ace2005event_trigger/ace2005event_dev.json 32 | python data_scripts/jsonl2json.py -i ../data/ace2005event_trigger/test.json -o ../data/ace2005event_trigger/ace2005event_test.json 33 | python data_scripts/jsonl2json.py -i ../data/ace2005event_argument/train.json -o ../data/ace2005event_argument/ace2005event_train.json 34 | python data_scripts/jsonl2json.py -i ../data/ace2005event_argument/dev.json -o ../data/ace2005event_argument/ace2005event_dev.json 35 | python data_scripts/jsonl2json.py -i ../data/ace2005event_argument/test.json -o ../data/ace2005event_argument/ace2005event_test.json 36 | 37 | -------------------------------------------------------------------------------- /src/data_scripts/ade.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ../data/ade/ 2 | wget -r -nH --cut-dirs=100 --reject "index.html*" --no-parent http://lavis.cs.hs-rm.de/storage/spert/public/datasets/ade/ -P ../data/ade/ 3 | cp -rf ../data/ade/ ../data/ade_re 4 | mv ../data/ade_re/ade_full.json ../data/ade_re/ade_re_full.json 5 | mv ../data/ade_re/ade_split_0_train.json ../data/ade_re/ade_re_split_0_train.json 6 | mv ../data/ade_re/ade_split_0_test.json ../data/ade_re/ade_re_split_0_test.json 7 | mv ../data/ade_re/ade_split_1_train.json ../data/ade_re/ade_re_split_1_train.json 8 | mv ../data/ade_re/ade_split_1_test.json ../data/ade_re/ade_re_split_1_test.json 9 | mv ../data/ade_re/ade_split_2_train.json ../data/ade_re/ade_re_split_2_train.json 10 | mv ../data/ade_re/ade_split_2_test.json ../data/ade_re/ade_re_split_2_test.json 11 | mv ../data/ade_re/ade_split_3_train.json ../data/ade_re/ade_re_split_3_train.json 12 | mv ../data/ade_re/ade_split_3_test.json ../data/ade_re/ade_re_split_3_test.json 13 | mv ../data/ade_re/ade_split_4_train.json ../data/ade_re/ade_re_split_4_train.json 14 | mv ../data/ade_re/ade_split_4_test.json ../data/ade_re/ade_re_split_4_test.json 15 | mv ../data/ade_re/ade_split_5_train.json ../data/ade_re/ade_re_split_5_train.json 16 | 
mv ../data/ade_re/ade_split_5_test.json ../data/ade_re/ade_re_split_5_test.json 17 | mv ../data/ade_re/ade_split_6_train.json ../data/ade_re/ade_re_split_6_train.json 18 | mv ../data/ade_re/ade_split_6_test.json ../data/ade_re/ade_re_split_6_test.json 19 | mv ../data/ade_re/ade_split_7_train.json ../data/ade_re/ade_re_split_7_train.json 20 | mv ../data/ade_re/ade_split_7_test.json ../data/ade_re/ade_re_split_7_test.json 21 | mv ../data/ade_re/ade_split_8_train.json ../data/ade_re/ade_re_split_8_train.json 22 | mv ../data/ade_re/ade_split_8_test.json ../data/ade_re/ade_re_split_8_test.json 23 | mv ../data/ade_re/ade_split_9_train.json ../data/ade_re/ade_re_split_9_train.json 24 | mv ../data/ade_re/ade_split_9_test.json ../data/ade_re/ade_re_split_9_test.json -------------------------------------------------------------------------------- /src/data_scripts/atis.sh: -------------------------------------------------------------------------------- 1 | rm -r ../data/atis 2 | git clone https://github.com/90217/joint-intent-classification-and-slot-filling-based-on-BERT.git 3 | mv joint-intent-classification-and-slot-filling-based-on-BERT/data/atis ../data/atis 4 | mv ../data/atis/valid ../data/atis/dev 5 | rm -r joint-intent-classification-and-slot-filling-based-on-BERT 6 | -------------------------------------------------------------------------------- /src/data_scripts/conll04.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ../data/conll04/ 2 | wget -r -nH --cut-dirs=100 --reject "index.html*" --no-parent http://lavis.cs.hs-rm.de/storage/spert/public/datasets/conll04/ -P ../data/conll04/ 3 | cp -rf ../data/conll04/ ../data/conll04_re 4 | mv ../data/conll04_re/conll04_train.json ../data/conll04_re/conll04_re_train.json 5 | mv ../data/conll04_re/conll04_dev.json ../data/conll04_re/conll04_re_dev.json 6 | mv ../data/conll04_re/conll04_test.json ../data/conll04_re/conll04_re_test.json -------------------------------------------------------------------------------- /src/data_scripts/conll05_srl.sh: -------------------------------------------------------------------------------- 1 | PTB_PATH=$1 2 | echo "check_certificate = off" >> ~/.wgetrc 3 | if [ ! -n "$PTB_PATH" ]; then 4 | echo -e "\033[31mPlease untar penn treebank iii data from https://catalog.ldc.upenn.edu/LDC99T42 to path PTB_PATH (absolute path)\033[0m" 5 | exit 1 6 | fi 7 | git clone https://github.com/luheng/deep_srl.git 8 | cd ./deep_srl/ 9 | 10 | # prepare srlconll-1.1.tgz 11 | SRLPATH="./data/srl" 12 | if [ ! -d $SRLPATH ]; then 13 | mkdir -p $SRLPATH 14 | fi 15 | 16 | # Get srl-conll package. 
17 | wget -O "${SRLPATH}/srlconll-1.1.tgz" --no-check-certificate http://www.lsi.upc.edu/~srlconll/srlconll-1.1.tgz 18 | tar xf "${SRLPATH}/srlconll-1.1.tgz" -C "${SRLPATH}" 19 | rm "${SRLPATH}/srlconll-1.1.tgz" 20 | sudo apt-get install tcsh -y 21 | 22 | sudo chmod +x ./scripts/fetch_and_make_conll05_data.sh 23 | bash ./scripts/fetch_and_make_conll05_data.sh $PTB_PATH 24 | cd ../ 25 | cp -rf ./deep_srl/data/srl ../data/ 26 | mv ../data/srl ../data/conll05_srl 27 | cp -rf ../data/conll05_srl ../data/conll05_srl_wsj 28 | cp -rf ../data/conll05_srl ../data/conll05_srl_brown 29 | rm -rf ./deep_srl 30 | 31 | cp ../data/conll05_srl/conll05.devel.txt ../data/conll05_srl/conll05.dev.txt 32 | cp ../data/conll05_srl_wsj/conll05.test.wsj.txt ../data/conll05_srl_wsj/conll05.test.txt 33 | cp ../data/conll05_srl_brown/conll05.test.brown.txt ../data/conll05_srl_brown/conll05.test.txt 34 | -------------------------------------------------------------------------------- /src/data_scripts/conll12_srl.sh: -------------------------------------------------------------------------------- 1 | ONTONOTES_PATH=$1 2 | if [ ! -n "$ONTONOTES_PATH" ]; then 3 | echo -e "\033[31mPlease unzip the conll12 srl (OntoNotes 5.0) data from https://catalog.ldc.upenn.edu/LDC2013T19 at path ONTONOTES_PATH (absolute path)\033[0m" 4 | exit 1 5 | fi 6 | echo "ONTONOTES_PATH = ${ONTONOTES_PATH}" 7 | 8 | # Prepare conll12 sample ids in each split 9 | wget -P ../data https://github.com/ontonotes/conll-formatted-ontonotes-5.0/archive/refs/tags/v12.tar.gz 10 | cd ../data 11 | tar -xvzf v12.tar.gz 12 | cd ../deepstruct 13 | 14 | 15 | conda create -n python2 -y python=2.7.18 16 | source activate python2 17 | # Convert the data into the CoNLL format 18 | cd ./data_scripts 19 | bash skeleton2conll.sh -D "${ONTONOTES_PATH}/data/files/data" ../../data/conll-formatted-ontonotes-5.0-12/conll-formatted-ontonotes-5.0/ 20 | cd ..
21 | 22 | git clone git@github.com:luheng/deep_srl.git 23 | cd ./deep_srl/ 24 | bash ./scripts/make_conll2012_data.sh ../../data/conll-formatted-ontonotes-5.0-12/conll-formatted-ontonotes-5.0/ 25 | source activate base 26 | cd ../ 27 | cp -rf ./deep_srl/data/srl ../data/ 28 | rm -rf ../data/conll12_srl 29 | mv ../data/srl ../data/conll12_srl 30 | rm -rf ./deep_srl 31 | mv ../data/conll12_srl/conll2012.devel.txt ../data/conll12_srl/conll2012.dev.txt 32 | -------------------------------------------------------------------------------- /src/data_scripts/jsonl2json.py: -------------------------------------------------------------------------------- 1 | from pyheaven import *  # provides HeavenArguments, LoadJson, SaveJson 2 | 3 | if __name__=="__main__": 4 | args = HeavenArguments.from_parser([ 5 | StrArgumentDescriptor("input",short="i",default=None), 6 | StrArgumentDescriptor("output",short="o",default=None), 7 | StrArgumentDescriptor("indent",short="t",default=None), 8 | ]) 9 | if args.input is not None and args.output is not None: 10 | SaveJson(LoadJson(args.input, backend='jsonl'), args.output, indent=int(args.indent) if args.indent is not None else None)  # read JSONL, write a single JSON array; the -t flag arrives as a string, so cast it --------------------------------------------------------------------------------
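For reference, the conversion that `jsonl2json.py` performs through `pyheaven` can be sketched with the standard library alone; the file name is hypothetical and `argparse` stands in for `HeavenArguments`:

```python
# jsonl2json_stdlib.py -- illustrative stand-in for data_scripts/jsonl2json.py
import argparse
import json

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert JSONL to a single JSON array.")
    parser.add_argument("-i", "--input", required=True, help="input .jsonl file")
    parser.add_argument("-o", "--output", required=True, help="output .json file")
    parser.add_argument("-t", "--indent", type=int, default=None, help="JSON indent width")
    args = parser.parse_args()

    # Each non-empty line of a JSONL file is one standalone JSON object.
    with open(args.input, encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]

    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=args.indent)
```

Either variant turns a JSONL file (one JSON object per line) into a single JSON array, which is the format the `jsonl2json.py` calls in `ace2005event.sh` above produce for the trigger and argument splits.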
/src/data_scripts/multi_woz.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/jasonwu0731/trade-dst.git 2 | cp ./data_scripts/multi_woz_create_data.py ./trade-dst/ 3 | cd ./trade-dst/ 4 | python multi_woz_create_data.py 5 | cd ../ 6 | cp ./trade-dst/utils/fix_label.py ./dataset_processing/preprocess_multiwoz/ 7 | cd ./dataset_processing/preprocess_multiwoz/ 8 | python prepare_multi_woz.py --data-dir ../../trade-dst/data 9 | cd ../../ 10 | cp -rf ./trade-dst/data/splits ../data/ 11 | mv ../data/splits ../data/multi_woz 12 | rm -rf ./trade-dst 13 | -------------------------------------------------------------------------------- /src/data_scripts/nyt.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ../data/nyt/ 2 | git clone https://github.com/yubowen-ph/JointER.git 3 | rm -r ../data/nyt 4 | mv JointER/dataset/NYT-multi/data ../data/nyt 5 | rm -r JointER 6 | ~/.local/bin/gdown --fuzzy https://drive.google.com/file/d/1kguS3pHc7F0NmjJvU-aSmqvKPYAeaAKk/view -O ../data/nyt/schemas.json 7 | cp -rf ../data/nyt/ ../data/nyt_re 8 | -------------------------------------------------------------------------------- /src/data_scripts/skeleton2conll.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function usage { 4 | cat <<EOF 5 | 6 | ---------------------------------------------------------------------------------------------------- 7 | 8 | Usage: 9 | 10 | $0 [-d] [-h] -D <ontonotes-data-dir> <conll-formatted-ontonotes-dir> 11 | 12 | 13 | 14 | 15 | Description: 16 | ----------- 17 | 18 | <ontonotes-data-dir>: Location of the data directory under the OntoNotes 4.0 release 19 | <conll-formatted-ontonotes-dir>: The directory inside which the *_skel files exist and need to 20 | be converted to .conll files 21 | 22 | ---------------------------------------------------------------------------------------------------- 23 | 24 | 25 | 26 | 27 | EOF 28 | exit; 29 | } 30 | 31 | 32 | function message 33 | { 34 | (echo "----------------------------------------------------------------------------------------------------"; 35 | echo "" ; 36 | echo $* ; 37 | echo "" ; 38 | echo "----------------------------------------------------------------------------------------------------") 1>&2 39 | 40 | } 41 | 42 | function warning 43 | { 44 | message "$*" 45 | } 46 | 47 | function error 48 | { 49 | message "$*" 50 | exit 51 | } 52 | 53 | 54 | function r { echo ${1%.*}; } 55 | function t { echo ${1##*/}; } 56 | function e { echo $(t ${1##*.}); } 57 | function h { echo ${1%/*}; } 58 | 59 | 60 | 61 | # define helper function: run a command and print its exit code 62 | function erun () 63 | { 64 | local debug; 65 | local verbose; 66 | debug=0; 67 | if [[ $1 == "-d" ]]; then 68 | debug=1; 69 | shift; 70 | fi; 71 | verbose=0; 72 | if [[ $1 == "-v" ]]; then 73 | verbose=1; 74 | shift; 75 | fi; 76 | if [[ $DEBUG -eq 1 ]]; then 77 | debug=1; 78 | fi; 79 | if [[ $VERBOSE -eq 1 ]]; then 80 | verbose=1; 81 | fi; 82 | if [[ $debug -eq 1 ]]; then 83 | echo "eval $1"; 84 | else 85 | if [[ $verbose -eq 1 ]]; then 86 | echo "-> $1"; 87 | fi; 88 | eval $1; 89 | fi; 90 | local code=$?; 91 | if [ $code -ne 0 ]; then 92 | echo "Exit code: $code"; 93 | break; 94 | fi 95 | } 96 | 97 | 98 | 99 | # handle the valid command line options 100 | DEBUG=0 101 | TESTING=false 102 | VERBOSE=0 103 | DEBUG_OPTION="" 104 | EDITED="" 105 | while getopts D:dhT opt 106 | do 107 | case "$opt" in 108 | v) 109 | VERBOSE=1;; 110 | 111 | d) 112 | DEBUG=1 113 | DEBUG_OPTION="-d";; 114 | 115 | D) 116 | ON_DATA_DIR="$OPTARG" 117 | ON_DATA_DIR=${ON_DATA_DIR%/} 118 | 119 | if [[ -z $ON_DATA_DIR ]]; then 120 | error "please specify a valid ontonotes data directory using the -D option" 121 | usage 122 | fi;; 123 | 124 | T) 125 | # this option is used internally for testing 126 | TESTING=true;; 127 | 128 | \?) 129 | usage 130 | exit 1;; 131 | 132 | h) 133 | usage 134 | exit 0;; 135 | 136 | :) 137 | echo "option -$OPTARG requires an argument" 138 | usage 139 | exit 1;; 140 | 141 | esac 142 | done 143 | shift `expr $OPTIND - 1` 144 | 145 | 146 | 147 | 148 | # at this point $* contains the arguments after interpreting the options 149 | 150 | d=$1 151 | d=${d%/} 152 | 153 | 154 | # if the conll release directory is not correct 155 | if [[ $(t $d) != "conll-formatted-ontonotes-5.0" ]]; then 156 | error "please make sure that you are pointing to the directory 'conll-formatted-ontonotes-5.0'" 157 | fi 158 | 159 | 160 | 161 | # if we are testing the release, we do not want to clobber the 162 | # true _conll files 163 | if $TESTING; then 164 | EXT="_skel2conll" 165 | else 166 | EXT="_conll" 167 | fi 168 | 169 | 170 | # if no arguments are specified, then just print usage 171 | if [[ $# -eq 0 ]]; then 172 | usage 173 | fi 174 | 175 | 176 | 177 | 178 | for language in arabic english chinese; do 179 | # set the EDITED option only for english 180 | if [[ $language == "english" ]]; then 181 | EDITED="-edited" 182 | else 183 | EDITED="" 184 | fi 185 | 186 | for partition in train development test conll-2012-test; do 187 | 188 | if [[ -d $d/data/$partition/data/$language/ ]]; then 189 | for skel in $(find $d/data/$partition/data/$language/ -name "*_skel"); do 190 | gold_parse=$ON_DATA_DIR/$(r ${skel/*data\//}).parse 191 | 192 | if [[ ! -e $gold_parse ]]; then 193 | error "could not find the gold parse [$gold_parse] in the ontonotes distribution ... exiting ..." 194 | exit 195 | fi 196 | 197 | conll=${skel/_skel/$EXT} 198 | erun -v "python skeleton2conll.py $gold_parse $skel $conll $EDITED --text" 199 | done 200 | fi 201 | done 202 | done 203 | 204 | 205 | 206 | 207 | # complain if the exit status of the last command executed is non-zero 208 | if [[ $? 
!= 0 ]]; then echo "the last command exited with a non-zero status" 1>&2; fi 209 | 210 | 211 | -------------------------------------------------------------------------------- /src/data_scripts/snips.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ../data/snips/ 2 | mkdir -p ../data/snips/test/ 3 | mkdir -p ../data/snips/train/ 4 | mkdir -p ../data/snips/dev/ 5 | wget https://raw.githubusercontent.com/90217/joint-intent-classification-and-slot-filling-based-on-BERT/master/data/snips/test/label -P ../data/snips/test 6 | wget https://raw.githubusercontent.com/90217/joint-intent-classification-and-slot-filling-based-on-BERT/master/data/snips/test/seq.in -P ../data/snips/test 7 | wget https://raw.githubusercontent.com/90217/joint-intent-classification-and-slot-filling-based-on-BERT/master/data/snips/test/seq.out -P ../data/snips/test 8 | wget https://raw.githubusercontent.com/90217/joint-intent-classification-and-slot-filling-based-on-BERT/master/data/snips/train/label -P ../data/snips/train 9 | wget https://raw.githubusercontent.com/90217/joint-intent-classification-and-slot-filling-based-on-BERT/master/data/snips/train/seq.in -P ../data/snips/train 10 | wget https://raw.githubusercontent.com/90217/joint-intent-classification-and-slot-filling-based-on-BERT/master/data/snips/train/seq.out -P ../data/snips/train 11 | wget https://raw.githubusercontent.com/90217/joint-intent-classification-and-slot-filling-based-on-BERT/master/data/snips/valid/label -P ../data/snips/dev 12 | wget https://raw.githubusercontent.com/90217/joint-intent-classification-and-slot-filling-based-on-BERT/master/data/snips/valid/seq.in -P ../data/snips/dev 13 | wget https://raw.githubusercontent.com/90217/joint-intent-classification-and-slot-filling-based-on-BERT/master/data/snips/valid/seq.out -P ../data/snips/dev 14 | -------------------------------------------------------------------------------- /src/dataset_processing/ace2005event_types.json: -------------------------------------------------------------------------------- 1 | { 2 | "entities": { 3 | "VEH:Underspecified": { 4 | "short": "VEH:Underspecified", 5 | "verbose": "underspecified vehicle" 6 | }, 7 | "Justice:Execute": { 8 | "short": "Justice:Execute", 9 | "verbose": "execute" 10 | }, 11 | "Life:Die": { 12 | "short": "Life:Die", 13 | "verbose": "die" 14 | }, 15 | "Business:Merge-Org": { 16 | "short": "Business:Merge-Org", 17 | "verbose": "merge organization" 18 | }, 19 | "Conflict:Attack": { 20 | "short": "Conflict:Attack", 21 | "verbose": "attack" 22 | }, 23 | "Justice:Arrest-Jail": { 24 | "short": "Justice:Arrest-Jail", 25 | "verbose": "arrest jail" 26 | }, 27 | "Numeric:Money": { 28 | "short": "Numeric:Money", 29 | "verbose": "money" 30 | }, 31 | "Personnel:Start-Position": { 32 | "short": "Personnel:Start-Position", 33 | "verbose": "start position" 34 | }, 35 | "VEH:Air": { 36 | "short": "VEH:Air", 37 | "verbose": "air vehicle" 38 | }, 39 | "ORG:Religious": { 40 | "short": "ORG:Religious", 41 | "verbose": "religious organization" 42 | }, 43 | "Personnel:End-Position": { 44 | "short": "Personnel:End-Position", 45 | "verbose": "end position" 46 | }, 47 | "VEH:Water": { 48 | "short": "VEH:Water", 49 | "verbose": "water vehicle" 50 | }, 51 | "Justice:Appeal": { 52 | "short": "Justice:Appeal", 53 | "verbose": "appeal" 54 | }, 55 | "Transaction:Transfer-Money": { 56 | "short": "Transaction:Transfer-Money", 57 | "verbose": "transfer money" 58 | }, 59 | "WEA:Exploding": { 60 | "short": "WEA:Exploding", 61 | 
"verbose": "exploding weapon" 62 | }, 63 | "WEA:Shooting": { 64 | "short": "WEA:Shooting", 65 | "verbose": "shooting weapon" 66 | }, 67 | "Contact:Meet": { 68 | "short": "Contact:Meet", 69 | "verbose": "meet" 70 | }, 71 | "WEA:Projectile": { 72 | "short": "WEA:Projectile", 73 | "verbose": "projectile weapon" 74 | }, 75 | "WEA:Sharp": { 76 | "short": "WEA:Sharp", 77 | "verbose": "sharp weapon" 78 | }, 79 | "ORG:Non-Governmental": { 80 | "short": "ORG:Non-Governmental", 81 | "verbose": "non governmental organization" 82 | }, 83 | "Numeric:Percent": { 84 | "short": "Numeric:Percent", 85 | "verbose": "percent" 86 | }, 87 | "ORG:Educational": { 88 | "short": "ORG:Educational", 89 | "verbose": "educational organization" 90 | }, 91 | "FAC:Building-Grounds": { 92 | "short": "FAC:Building-Grounds", 93 | "verbose": "building grounds" 94 | }, 95 | "WEA:Chemical": { 96 | "short": "WEA:Chemical", 97 | "verbose": "chemical weapon" 98 | }, 99 | "Transaction:Transfer-Ownership": { 100 | "short": "Transaction:Transfer-Ownership", 101 | "verbose": "transfer ownership" 102 | }, 103 | "LOC:Region-General": { 104 | "short": "LOC:Region-General", 105 | "verbose": "region general location" 106 | }, 107 | "Business:End-Org": { 108 | "short": "Business:End-Org", 109 | "verbose": "end organization" 110 | }, 111 | "Personnel:Elect": { 112 | "short": "Personnel:Elect", 113 | "verbose": "elect" 114 | }, 115 | "FAC:Path": { 116 | "short": "FAC:Path", 117 | "verbose": "path" 118 | }, 119 | "Justice:Trial-Hearing": { 120 | "short": "Justice:Trial-Hearing", 121 | "verbose": "trial hearing" 122 | }, 123 | "Conflict:Demonstrate": { 124 | "short": "Conflict:Demonstrate", 125 | "verbose": "demonstrate" 126 | }, 127 | "WEA:Underspecified": { 128 | "short": "WEA:Underspecified", 129 | "verbose": "underspecified weapon" 130 | }, 131 | "Contact-Info:URL": { 132 | "short": "Contact-Info:URL", 133 | "verbose": "url" 134 | }, 135 | "PER:Individual": { 136 | "short": "PER:Individual", 137 | "verbose": "individual" 138 | }, 139 | "Justice:Acquit": { 140 | "short": "Justice:Acquit", 141 | "verbose": "acquit" 142 | }, 143 | "ORG:Government": { 144 | "short": "ORG:Government", 145 | "verbose": "government" 146 | }, 147 | "PER:Indeterminate": { 148 | "short": "PER:Indeterminate", 149 | "verbose": "indeterminate" 150 | }, 151 | "LOC:Celestial": { 152 | "short": "LOC:Celestial", 153 | "verbose": "celestial" 154 | }, 155 | "Life:Be-Born": { 156 | "short": "Life:Be-Born", 157 | "verbose": "be born" 158 | }, 159 | "Business:Declare-Bankruptcy": { 160 | "short": "Business:Declare-Bankruptcy", 161 | "verbose": "declare bankruptcy" 162 | }, 163 | "FAC:Airport": { 164 | "short": "FAC:Airport", 165 | "verbose": "airport" 166 | }, 167 | "WEA:Biological": { 168 | "short": "WEA:Biological", 169 | "verbose": "biological weapon" 170 | }, 171 | "VEH:Subarea-Vehicle": { 172 | "short": "VEH:Subarea-Vehicle", 173 | "verbose": "subarea vehicle" 174 | }, 175 | "LOC:Water-Body": { 176 | "short": "LOC:Water-Body", 177 | "verbose": "water body location" 178 | }, 179 | "Life:Injure": { 180 | "short": "Life:Injure", 181 | "verbose": "injure" 182 | }, 183 | "LOC:Land-Region-Natural": { 184 | "short": "LOC:Land-Region-Natural", 185 | "verbose": "land region natural" 186 | }, 187 | "PER:Group": { 188 | "short": "PER:Group", 189 | "verbose": "group" 190 | }, 191 | "Justice:Fine": { 192 | "short": "Justice:Fine", 193 | "verbose": "fine" 194 | }, 195 | "Business:Start-Org": { 196 | "short": "Business:Start-Org", 197 | "verbose": "start organization" 198 | }, 199 | 
"Justice:Sue": { 200 | "short": "Justice:Sue", 201 | "verbose": "sue" 202 | }, 203 | "Personnel:Nominate": { 204 | "short": "Personnel:Nominate", 205 | "verbose": "nominate" 206 | }, 207 | "FAC:Plant": { 208 | "short": "FAC:Plant", 209 | "verbose": "plant" 210 | }, 211 | "ORG:Medical-Science": { 212 | "short": "ORG:Medical-Science", 213 | "verbose": "medical science" 214 | }, 215 | "Justice:Release-Parole": { 216 | "short": "Justice:Release-Parole", 217 | "verbose": "release parole" 218 | }, 219 | "GPE:Continent": { 220 | "short": "GPE:Continent", 221 | "verbose": "continent" 222 | }, 223 | "Life:Divorce": { 224 | "short": "Life:Divorce", 225 | "verbose": "divorce" 226 | }, 227 | "Justice:Convict": { 228 | "short": "Justice:Convict", 229 | "verbose": "convict" 230 | }, 231 | "LOC:Address": { 232 | "short": "LOC:Address", 233 | "verbose": "address" 234 | }, 235 | "ORG:Commercial": { 236 | "short": "ORG:Commercial", 237 | "verbose": "commercial organization" 238 | }, 239 | "Justice:Sentence": { 240 | "short": "Justice:Sentence", 241 | "verbose": "sentence" 242 | }, 243 | "ORG:Media": { 244 | "short": "ORG:Media", 245 | "verbose": "media" 246 | }, 247 | "Movement:Transport": { 248 | "short": "Movement:Transport", 249 | "verbose": "transport" 250 | }, 251 | "FAC:Subarea-Facility": { 252 | "short": "FAC:Subarea-Facility", 253 | "verbose": "subarea facility" 254 | }, 255 | "LOC:Region-International": { 256 | "short": "LOC:Region-International", 257 | "verbose": "region international" 258 | }, 259 | "Sentence": { 260 | "short": "Sentence", 261 | "verbose": "sentence" 262 | }, 263 | "ORG:Entertainment": { 264 | "short": "ORG:Entertainment", 265 | "verbose": "entertainment organization" 266 | }, 267 | "LOC:Boundary": { 268 | "short": "LOC:Boundary", 269 | "verbose": "boundary" 270 | }, 271 | "Justice:Charge-Indict": { 272 | "short": "Justice:Charge-Indict", 273 | "verbose": "charge indict" 274 | }, 275 | "Crime": { 276 | "short": "Crime", 277 | "verbose": "crime" 278 | }, 279 | "WEA:Blunt": { 280 | "short": "WEA:Blunt", 281 | "verbose": "blunt weapon" 282 | }, 283 | "Contact-Info:Phone-Number": { 284 | "short": "Contact-Info:Phone-Number", 285 | "verbose": "phone number" 286 | }, 287 | "WEA:Nuclear": { 288 | "short": "WEA:Nuclear", 289 | "verbose": "nuclear weapon" 290 | }, 291 | "ORG:Sports": { 292 | "short": "ORG:Sports", 293 | "verbose": "sports" 294 | }, 295 | "Job-Title": { 296 | "short": "Job-Title", 297 | "verbose": "job title" 298 | }, 299 | "GPE:Special": { 300 | "short": "GPE:Special", 301 | "verbose": "special location" 302 | }, 303 | "VEH:Land": { 304 | "short": "VEH:Land", 305 | "verbose": "land vehicle" 306 | }, 307 | "Contact:Phone-Write": { 308 | "short": "Contact:Phone-Write", 309 | "verbose": "phone write" 310 | }, 311 | "GPE:Nation": { 312 | "short": "GPE:Nation", 313 | "verbose": "nation" 314 | }, 315 | "GPE:State-or-Province": { 316 | "short": "GPE:State-or-Province", 317 | "verbose": "state or province" 318 | }, 319 | "Justice:Pardon": { 320 | "short": "Justice:Pardon", 321 | "verbose": "pardon" 322 | }, 323 | "GPE:GPE-Cluster": { 324 | "short": "GPE:GPE-Cluster", 325 | "verbose": "gpe cluster" 326 | }, 327 | "Life:Marry": { 328 | "short": "Life:Marry", 329 | "verbose": "marry" 330 | }, 331 | "TIM:time": { 332 | "short": "TIM:time", 333 | "verbose": "time" 334 | }, 335 | "Contact-Info:E-Mail": { 336 | "short": "Contact-Info:E-Mail", 337 | "verbose": "email" 338 | }, 339 | "GPE:County-or-District": { 340 | "short": "GPE:County-or-District", 341 | "verbose": "county or district" 
342 | }, 343 | "Justice:Extradite": { 344 | "short": "Justice:Extradite", 345 | "verbose": "extradite" 346 | }, 347 | "GPE:Population-Center": { 348 | "short": "GPE:Population-Center", 349 | "verbose": "population center" 350 | } 351 | }, 352 | "relations": { 353 | "Time-Before": { 354 | "short": "Time-Before", 355 | "verbose": "time before", 356 | "symmetric": false 357 | }, 358 | "Agent": { 359 | "short": "Agent", 360 | "verbose": "agent", 361 | "symmetric": false 362 | }, 363 | "Buyer": { 364 | "short": "Buyer", 365 | "verbose": "buyer", 366 | "symmetric": false 367 | }, 368 | "Prosecutor": { 369 | "short": "Prosecutor", 370 | "verbose": "prosecutor", 371 | "symmetric": false 372 | }, 373 | "Time-Within": { 374 | "short": "Time-Within", 375 | "verbose": "time within", 376 | "symmetric": false 377 | }, 378 | "Money": { 379 | "short": "Money", 380 | "verbose": "money", 381 | "symmetric": false 382 | }, 383 | "Entity": { 384 | "short": "Entity", 385 | "verbose": "entity", 386 | "symmetric": false 387 | }, 388 | "Person": { 389 | "short": "Person", 390 | "verbose": "person", 391 | "symmetric": false 392 | }, 393 | "Seller": { 394 | "short": "Seller", 395 | "verbose": "seller", 396 | "symmetric": false 397 | }, 398 | "Time-At-Beginning": { 399 | "short": "Time-At-Beginning", 400 | "verbose": "time at beginning", 401 | "symmetric": false 402 | }, 403 | "Target": { 404 | "short": "Target", 405 | "verbose": "target", 406 | "symmetric": false 407 | }, 408 | "Place": { 409 | "short": "Place", 410 | "verbose": "place", 411 | "symmetric": false 412 | }, 413 | "Artifact": { 414 | "short": "Artifact", 415 | "verbose": "artifact", 416 | "symmetric": false 417 | }, 418 | "Origin": { 419 | "short": "Origin", 420 | "verbose": "origin", 421 | "symmetric": false 422 | }, 423 | "Instrument": { 424 | "short": "Instrument", 425 | "verbose": "instrument", 426 | "symmetric": false 427 | }, 428 | "Beneficiary": { 429 | "short": "Beneficiary", 430 | "verbose": "beneficiary", 431 | "symmetric": false 432 | }, 433 | "Destination": { 434 | "short": "Destination", 435 | "verbose": "destination", 436 | "symmetric": false 437 | }, 438 | "Recipient": { 439 | "short": "Recipient", 440 | "verbose": "recipient", 441 | "symmetric": false 442 | }, 443 | "Time-Ending": { 444 | "short": "Time-Ending", 445 | "verbose": "time ending", 446 | "symmetric": false 447 | }, 448 | "Org": { 449 | "short": "Org", 450 | "verbose": "org", 451 | "symmetric": false 452 | }, 453 | "Time-At-End": { 454 | "short": "Time-At-End", 455 | "verbose": "time at end", 456 | "symmetric": false 457 | }, 458 | "Vehicle": { 459 | "short": "Vehicle", 460 | "verbose": "vehicle", 461 | "symmetric": false 462 | }, 463 | "Adjudicator": { 464 | "short": "Adjudicator", 465 | "verbose": "adjudicator", 466 | "symmetric": false 467 | }, 468 | "Sentence": { 469 | "short": "Sentence", 470 | "verbose": "sentence", 471 | "symmetric": false 472 | }, 473 | "Crime": { 474 | "short": "Crime", 475 | "verbose": "crime", 476 | "symmetric": false 477 | }, 478 | "Victim": { 479 | "short": "Victim", 480 | "verbose": "victim", 481 | "symmetric": false 482 | }, 483 | "Price": { 484 | "short": "Price", 485 | "verbose": "price", 486 | "symmetric": false 487 | }, 488 | "Time-Holds": { 489 | "short": "Time-Holds", 490 | "verbose": "time holds", 491 | "symmetric": false 492 | }, 493 | "Time-Starting": { 494 | "short": "Time-Starting", 495 | "verbose": "time starting", 496 | "symmetric": false 497 | }, 498 | "Position": { 499 | "short": "Position", 500 | "verbose": "position", 501 | 
"symmetric": false 502 | }, 503 | "Defendant": { 504 | "short": "Defendant", 505 | "verbose": "defendant", 506 | "symmetric": false 507 | }, 508 | "Giver": { 509 | "short": "Giver", 510 | "verbose": "giver", 511 | "symmetric": false 512 | }, 513 | "Attacker": { 514 | "short": "Attacker", 515 | "verbose": "attacker", 516 | "symmetric": false 517 | }, 518 | "Plaintiff": { 519 | "short": "Plaintiff", 520 | "verbose": "plaintiff", 521 | "symmetric": false 522 | }, 523 | "Time-After": { 524 | "short": "Time-After", 525 | "verbose": "time after", 526 | "symmetric": false 527 | } 528 | } 529 | } -------------------------------------------------------------------------------- /src/dataset_processing/arguments.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/amazon-science/tanl 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | import transformers 5 | 6 | 7 | @dataclass 8 | class TrainingArguments(transformers.TrainingArguments): 9 | """ 10 | Arguments for the Trainer. 11 | """ 12 | output_dir: str = field( 13 | default='experiments', 14 | metadata={"help": "The output directory where the results and model weights will be written."} 15 | ) 16 | 17 | zero_shot: bool = field( 18 | default=False, 19 | metadata={"help": "Zero-shot setting"} 20 | ) 21 | 22 | 23 | @dataclass 24 | class ModelArguments: 25 | """ 26 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 27 | """ 28 | 29 | model_name_or_path: Optional[str] = field( 30 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 31 | ) 32 | 33 | config_name: Optional[str] = field( 34 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 35 | ) 36 | 37 | tokenizer_name: Optional[str] = field( 38 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 39 | ) 40 | 41 | cache_dir: Optional[str] = field( 42 | default=None, metadata={"help": "Where do you want to store the pretrained models"} 43 | ) 44 | 45 | 46 | @dataclass 47 | class DataTrainingArguments: 48 | """ 49 | Arguments pertaining to what data we are going to input our model for training and eval. 50 | """ 51 | 52 | datasets: Optional[str] = field( 53 | default=None, 54 | metadata={"help": "Comma separated list of dataset names, for training."} 55 | ) 56 | 57 | data_dir: Optional[str] = field( 58 | default="../../data/", 59 | metadata={"help": "Path to data directory"} 60 | ) 61 | 62 | eval_datasets: Optional[str] = field( 63 | default=None, 64 | metadata={"help": "Comma separated list of dataset names. Defaults to the train datasets."} 65 | ) 66 | 67 | train_split: str = field( 68 | default='train', 69 | metadata={"help": "The datasplit for training. Can be 'train', 'dev', 'test', etc."} 70 | ) 71 | 72 | max_seq_length: int = field( 73 | default=512, 74 | metadata={ 75 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 76 | "than this will be truncated, shorter sequences will be padded." 
77 | }, 78 | ) 79 | 80 | max_output_seq_length: Optional[int] = field( 81 | default=None, 82 | metadata={ 83 | "help": "The maximum output sequence length (default is the same as input)" 84 | }, 85 | ) 86 | 87 | overwrite_cache: bool = field( 88 | default=True, metadata={"help": "Overwrite the cached training and evaluation sets"} 89 | ) 90 | 91 | train_subset: float = field( 92 | default=1, metadata={"help": "The portion of training data to use"} 93 | ) 94 | 95 | episodes: str = field( 96 | default='0', metadata={"help": "Episode indices -- a single number such as 3 or an interval such as 1-4\n" 97 | "The index is also used as random seeds and this setting is therefore used to " 98 | "repeat multiple experiments."} 99 | ) 100 | 101 | num_beams: int = field( 102 | default=None, 103 | metadata={"help": "Number of beams for beam search during generation (only affects evaluation)"} 104 | ) 105 | 106 | max_seq_length_eval: int = field( 107 | default=None, 108 | metadata={ 109 | "help": "Maximum input sequence length at evaluation time (default is equal to max_seq_length)" 110 | }, 111 | ) 112 | 113 | max_output_seq_length_eval: int = field( 114 | default=None, 115 | metadata={ 116 | "help": "The maximum output sequence length at evaluation time (default is the same as input)" 117 | }, 118 | ) 119 | 120 | input_format: str = field( 121 | default=None, metadata={"help": "Input format"} 122 | ) 123 | 124 | output_format: str = field( 125 | default=None, metadata={"help": "Output format"} 126 | ) 127 | 128 | multitask: bool = field( 129 | default=False, metadata={"help": "If true, each input sentence is prepended with the dataset name"} 130 | ) 131 | 132 | 133 | num_shots: int = field( 134 | default=None, metadata={"help": "number of shots (few-shot argument for the FewRel dataset)"} 135 | ) 136 | 137 | num_ways: int = field( 138 | default=None, metadata={"help": "number of ways (few-shot argument for the FewRel dataset)"} 139 | ) 140 | 141 | num_query: int = field( 142 | default=5, metadata={"help": "number of query examples (few-shot argument for the FewRel dataset)"} 143 | ) 144 | 145 | 146 | chunk_size: int = field( 147 | default=128, metadata={"help": "Size of document chunks"} 148 | ) 149 | 150 | chunk_overlap: int = field( 151 | default=64, metadata={"help": "Size of overlap between consecutive chunks"} 152 | ) 153 | 154 | chunk_size_eval: int = field( 155 | default=None, metadata={"help": "Size of document chunks during evaluation (default is equal to chunk_size)"} 156 | ) 157 | 158 | chunk_overlap_eval: int = field( 159 | default=None, metadata={"help": "Size of overlap between consecutive chunks during evaluation " 160 | "(default is equal to chunk_overlap)"} 161 | ) 162 | 163 | eval_nll: bool = field( 164 | default=False, metadata={"help": "Evaluate using NLL (only applicable to certain datasets)"} 165 | ) 166 | -------------------------------------------------------------------------------- /src/dataset_processing/base_dataset.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/amazon-science/tanl 2 | import os 3 | import logging 4 | import random 5 | from typing import Dict, Generator, Tuple, List 6 | from abc import ABC, abstractmethod 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from torch.utils.data.dataset import Dataset 10 | from tqdm import tqdm 11 | from transformers import PreTrainedTokenizer, torch_distributed_zero_first, default_data_collator 12 | 13 | from arguments import 
DataTrainingArguments 14 | from input_example import InputFeatures, InputExample 15 | from input_formats import INPUT_FORMATS 16 | from output_formats import OUTPUT_FORMATS 17 | 18 | 19 | class BaseDataset(Dataset, ABC): 20 | """ 21 | Base class for all datasets. 22 | """ 23 | name = None 24 | data_name = None 25 | task_descriptor = None 26 | 27 | default_input_format = 'plain' 28 | default_output_format = None 29 | default_data_dir = 'data' 30 | 31 | num_episodes = 1 32 | 33 | def __init__( 34 | self, 35 | tokenizer: PreTrainedTokenizer, 36 | max_input_length: int, 37 | max_output_length: int, 38 | overwrite_cache: bool = False, 39 | mode: str = 'train', 40 | local_rank: int = -1, 41 | train_subset: float = 1, 42 | seed: int = None, 43 | shuffle: bool = True, 44 | data_args: DataTrainingArguments = None, 45 | is_eval: bool = False, 46 | ): 47 | if seed is not None: 48 | 49 | random.seed(seed) 50 | 51 | self.data_args = data_args 52 | self.tokenizer = tokenizer 53 | 54 | self.max_input_length = max_input_length 55 | self.max_output_length = max_output_length 56 | 57 | self.input_format = INPUT_FORMATS[ 58 | data_args.input_format if data_args.input_format is not None else self.default_input_format 59 | ]() 60 | self.output_format = OUTPUT_FORMATS[ 61 | data_args.output_format if data_args.output_format is not None else self.default_output_format 62 | ]() 63 | 64 | self.data_path = data_args.data_dir if data_args.data_dir is not None else self.default_data_dir 65 | 66 | self.is_eval = is_eval 67 | self.eval_nll = data_args.eval_nll 68 | 69 | cached_data_file = os.path.join( 70 | self.data_dir(), 71 | f"cached_{self.name}_{mode}_{tokenizer.__class__.__name__}_{max_input_length}_{max_output_length}" 72 | f"{'_multitask' if data_args.multitask else ''}.pth" 73 | ) 74 | 75 | with torch_distributed_zero_first(local_rank): 76 | 77 | 78 | 79 | if os.path.exists(cached_data_file) and not overwrite_cache: 80 | self.load_cached_data(cached_data_file) 81 | 82 | else: 83 | self.load_schema() 84 | self.examples = self.load_data(mode=mode, seed=seed) 85 | 86 | 87 | for example in self.examples: 88 | example.dataset = self 89 | 90 | self.features = self.compute_features( 91 | max_input_length=max_input_length, 92 | max_output_length=max_output_length, 93 | multitask=data_args.multitask, 94 | ) 95 | 96 | if local_rank in [-1, 0]: 97 | 98 | self.save_data(cached_data_file) 99 | 100 | 101 | self.indices = list(range(len(self.examples))) 102 | if seed is not None and shuffle: 103 | random.shuffle(self.indices) 104 | 105 | 106 | self.effective_size = round(train_subset * len(self.examples)) 107 | if train_subset != 1: 108 | logging.info(f"Effective dataset size reduced to {self.effective_size} ({train_subset * 100:.0f}%)") 109 | 110 | def __repr__(self): 111 | return f'Dataset {self.name}' 112 | 113 | def __len__(self): 114 | return self.effective_size 115 | 116 | def __getitem__(self, i: int) -> InputFeatures: 117 | return self.features[self.indices[i]] 118 | 119 | def get_example(self, i: int) -> InputExample: 120 | return self.examples[self.indices[i]] 121 | 122 | def data_dir(self): 123 | if self.data_name is not None: 124 | return os.path.join(self.data_path, self.data_name) 125 | else: 126 | return os.path.join(self.data_path, self.name) 127 | 128 | def load_cached_data(self, cached_data_file: str): 129 | d = torch.load(cached_data_file) 130 | self.examples, self.features = d['examples'], d['features'] 131 | 132 | def save_data(self, cached_data_file: str): 133 | torch.save({ 134 | 'examples': 
self.examples, 135 | 'features': self.features, 136 | }, cached_data_file) 137 | 138 | def load_schema(self): 139 | """ 140 | Load extra dataset information, such as entity/relation types. 141 | """ 142 | pass 143 | 144 | @abstractmethod 145 | def load_data_single_split(self, split: str, seed: int = None) -> List[InputExample]: 146 | """ 147 | Load data for a single split (train, dev, or test). 148 | """ 149 | pass 150 | 151 | def load_data(self, mode: str, seed: int = None) -> List[InputExample]: 152 | """ 153 | Load all data, where 'mode' is a comma-separated string of splits (or a list of split names). 154 | """ 155 | examples = [] 156 | 157 | if isinstance(mode, str): 158 | splits = mode.split(',') 159 | else: 160 | assert isinstance(mode, (list, tuple)) 161 | splits = mode 162 | 163 | for split in splits: 164 | examples += self.load_data_single_split(split, seed=seed) 165 | 166 | return examples 167 | 168 | def _warn_max_sequence_length(self, max_sequence_length: int, sentences: List[str], name: str): 169 | max_length_needed = max(len(self.tokenizer.tokenize(x)) for x in sentences) 170 | if max_length_needed > max_sequence_length: 171 | logging.warning( 172 | f'Max sequence length is {max_sequence_length} but the longest {name} sequence is ' 173 | f'{max_length_needed} tokens long' 174 | ) 175 | 176 | def compute_features(self, max_input_length: int, max_output_length: int, multitask: bool = False): 177 | input_sentences = [self.input_format.format_input(example, multitask=multitask) for example in self.examples] 178 | output_sentences = [self.output_format.format_output(example) for example in self.examples] 179 | 180 | input_tok = self.tokenizer.batch_encode_plus( 181 | input_sentences, 182 | max_length=max_input_length, 183 | return_tensors='pt', 184 | padding='max_length', 185 | truncation=True, 186 | ) 187 | self._warn_max_sequence_length(max_input_length, input_sentences, "input") 188 | 189 | output_tok = self.tokenizer.batch_encode_plus( 190 | output_sentences, 191 | max_length=max_output_length, 192 | return_tensors='pt', 193 | padding='max_length', 194 | truncation=True, 195 | ) 196 | self._warn_max_sequence_length(max_output_length, output_sentences, "output") 197 | 198 | assert input_tok.input_ids.size(0) == output_tok.input_ids.size(0) 199 | 200 | features = [] 201 | for sentence_input_ids, att_mask, label_input_ids in zip(input_tok.input_ids, input_tok.attention_mask, 202 | output_tok.input_ids): 203 | features.append(InputFeatures( 204 | input_ids=sentence_input_ids.tolist(), 205 | attention_mask=att_mask.tolist(), 206 | label_ids=label_input_ids.tolist() 207 | )) 208 | 209 | return features 210 | 211 | def generate_output_sentences(self, data_args: DataTrainingArguments, model, device, batch_size: int) \ 212 | -> Generator[Tuple[InputExample, str], None, None]: 213 | """ 214 | Generate pairs (example, output_sentence) for evaluation.
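        The dataset is batched with a DataLoader, each batch is decoded with model.generate using the evaluation-time length and beam settings, and one detokenized output sentence is yielded per input example.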
215 | """ 216 | test_data_loader = DataLoader( 217 | self, 218 | batch_size=batch_size, 219 | shuffle=False, 220 | collate_fn=default_data_collator, 221 | ) 222 | 223 | for i, inputs in tqdm(enumerate(test_data_loader), total=len(test_data_loader)): 224 | predictions = model.generate( 225 | inputs['input_ids'].to(device), 226 | max_length=data_args.max_output_seq_length_eval, 227 | num_beams=data_args.num_beams, 228 | ) 229 | 230 | for j, (input_ids, label_ids, prediction) in enumerate( 231 | zip(inputs['input_ids'], inputs['labels'], predictions)): 232 | current_id = i * batch_size + j 233 | example = self.get_example(current_id) 234 | output_sentence = self.tokenizer.decode(prediction, skip_special_tokens=True, 235 | clean_up_tokenization_spaces=False) 236 | 237 | yield example, output_sentence 238 | 239 | @abstractmethod 240 | def evaluate_dataset(self, data_args: DataTrainingArguments, model, device, batch_size: int, macro: bool = False) \ 241 | -> Dict[str, float]: 242 | """ 243 | Evaluate model on this dataset, returning the task-relevant metrics. 244 | """ 245 | pass 246 | -------------------------------------------------------------------------------- /src/dataset_processing/config.ini: -------------------------------------------------------------------------------- 1 | [oie_oie2016] 2 | datasets = oie_oie2016 3 | model_name_or_path = t5-base 4 | num_train_epochs = 1 5 | max_seq_length = 256 6 | max_seq_length_eval = 512 7 | per_device_train_batch_size = 4 8 | per_device_eval_batch_size = 4 9 | do_train = True 10 | do_eval = False 11 | do_predict = True 12 | 13 | [oie_nyt] 14 | datasets = oie_nyt 15 | model_name_or_path = t5-base 16 | num_train_epochs = 100 17 | max_seq_length = 256 18 | max_seq_length_eval = 512 19 | per_device_train_batch_size = 4 20 | per_device_eval_batch_size = 4 21 | do_train = True 22 | do_eval = False 23 | do_predict = True 24 | 25 | [oie_web] 26 | datasets = oie_web 27 | model_name_or_path = t5-base 28 | num_train_epochs = 100 29 | max_seq_length = 256 30 | max_seq_length_eval = 512 31 | per_device_train_batch_size = 4 32 | per_device_eval_batch_size = 4 33 | do_train = True 34 | do_eval = True 35 | do_predict = False 36 | 37 | [oie_penn] 38 | datasets = oie_penn 39 | model_name_or_path = t5-base 40 | num_train_epochs = 100 41 | max_seq_length = 256 42 | max_seq_length_eval = 512 43 | per_device_train_batch_size = 4 44 | per_device_eval_batch_size = 4 45 | do_train = True 46 | do_eval = True 47 | do_predict = False 48 | 49 | [conll04] 50 | datasets = conll04 51 | model_name_or_path = t5-base 52 | num_train_epochs = 100 53 | max_seq_length = 256 54 | max_seq_length_eval = 512 55 | train_split = train,dev 56 | per_device_train_batch_size = 4 57 | per_device_eval_batch_size = 4 58 | do_train = True 59 | do_eval = False 60 | do_predict = True 61 | episodes = 1 62 | num_beams = 8 63 | 64 | [conll04_re] 65 | datasets = conll04_re 66 | model_name_or_path = t5-base 67 | num_train_epochs = 100 68 | max_seq_length = 256 69 | max_seq_length_eval = 512 70 | train_split = train,dev 71 | per_device_train_batch_size = 4 72 | per_device_eval_batch_size = 4 73 | do_train = True 74 | do_eval = False 75 | do_predict = True 76 | episodes = 1 77 | num_beams = 8 78 | 79 | [ade] 80 | datasets = ade 81 | model_name_or_path = t5-base 82 | num_train_epochs = 100 83 | max_seq_length = 256 84 | max_seq_length_eval = 512 85 | per_device_train_batch_size = 4 86 | per_device_eval_batch_size = 4 87 | do_train = True 88 | do_eval = False 89 | do_predict = True 90 | episodes = 1 91 | 92 | 
[ade0] 93 | datasets = ade0 94 | model_name_or_path = t5-base 95 | num_train_epochs = 100 96 | max_seq_length = 256 97 | max_seq_length_eval = 512 98 | per_device_train_batch_size = 4 99 | per_device_eval_batch_size = 4 100 | do_train = True 101 | do_eval = False 102 | do_predict = True 103 | episodes = 1 104 | 105 | [ade1] 106 | datasets = ade1 107 | model_name_or_path = t5-base 108 | num_train_epochs = 100 109 | max_seq_length = 256 110 | max_seq_length_eval = 512 111 | per_device_train_batch_size = 4 112 | per_device_eval_batch_size = 4 113 | do_train = True 114 | do_eval = False 115 | do_predict = True 116 | episodes = 1 117 | 118 | [ade2] 119 | datasets = ade2 120 | model_name_or_path = t5-base 121 | num_train_epochs = 100 122 | max_seq_length = 256 123 | max_seq_length_eval = 512 124 | per_device_train_batch_size = 4 125 | per_device_eval_batch_size = 4 126 | do_train = True 127 | do_eval = False 128 | do_predict = True 129 | episodes = 1 130 | 131 | [ade3] 132 | datasets = ade3 133 | model_name_or_path = t5-base 134 | num_train_epochs = 100 135 | max_seq_length = 256 136 | max_seq_length_eval = 512 137 | per_device_train_batch_size = 4 138 | per_device_eval_batch_size = 4 139 | do_train = True 140 | do_eval = False 141 | do_predict = True 142 | episodes = 1 143 | 144 | [ade4] 145 | datasets = ade4 146 | model_name_or_path = t5-base 147 | num_train_epochs = 100 148 | max_seq_length = 256 149 | max_seq_length_eval = 512 150 | per_device_train_batch_size = 4 151 | per_device_eval_batch_size = 4 152 | do_train = True 153 | do_eval = False 154 | do_predict = True 155 | episodes = 1 156 | 157 | [ade5] 158 | datasets = ade5 159 | model_name_or_path = t5-base 160 | num_train_epochs = 100 161 | max_seq_length = 256 162 | max_seq_length_eval = 512 163 | per_device_train_batch_size = 4 164 | per_device_eval_batch_size = 4 165 | do_train = True 166 | do_eval = False 167 | do_predict = True 168 | episodes = 1 169 | 170 | [ade6] 171 | datasets = ade6 172 | model_name_or_path = t5-base 173 | num_train_epochs = 100 174 | max_seq_length = 256 175 | max_seq_length_eval = 512 176 | per_device_train_batch_size = 4 177 | per_device_eval_batch_size = 4 178 | do_train = True 179 | do_eval = False 180 | do_predict = True 181 | episodes = 1 182 | 183 | [ade7] 184 | datasets = ade7 185 | model_name_or_path = t5-base 186 | num_train_epochs = 100 187 | max_seq_length = 256 188 | max_seq_length_eval = 512 189 | per_device_train_batch_size = 4 190 | per_device_eval_batch_size = 4 191 | do_train = True 192 | do_eval = False 193 | do_predict = True 194 | episodes = 1 195 | 196 | [ade8] 197 | datasets = ade8 198 | model_name_or_path = t5-base 199 | num_train_epochs = 100 200 | max_seq_length = 256 201 | max_seq_length_eval = 512 202 | per_device_train_batch_size = 4 203 | per_device_eval_batch_size = 4 204 | do_train = True 205 | do_eval = False 206 | do_predict = True 207 | episodes = 1 208 | 209 | [ade9] 210 | datasets = ade9 211 | model_name_or_path = t5-base 212 | num_train_epochs = 100 213 | max_seq_length = 256 214 | max_seq_length_eval = 512 215 | per_device_train_batch_size = 4 216 | per_device_eval_batch_size = 4 217 | do_train = True 218 | do_eval = False 219 | do_predict = True 220 | episodes = 1 221 | 222 | [ade_re] 223 | datasets = ade_re 224 | model_name_or_path = t5-base 225 | num_train_epochs = 100 226 | max_seq_length = 256 227 | max_seq_length_eval = 512 228 | per_device_train_batch_size = 4 229 | per_device_eval_batch_size = 4 230 | do_train = True 231 | do_eval = False 232 | do_predict = True 
233 | episodes = 1 234 | 235 | [ade_re0] 236 | datasets = ade_re0 237 | model_name_or_path = t5-base 238 | num_train_epochs = 100 239 | max_seq_length = 256 240 | max_seq_length_eval = 512 241 | per_device_train_batch_size = 4 242 | per_device_eval_batch_size = 4 243 | do_train = True 244 | do_eval = False 245 | do_predict = True 246 | episodes = 1 247 | 248 | [ade_re1] 249 | datasets = ade_re1 250 | model_name_or_path = t5-base 251 | num_train_epochs = 100 252 | max_seq_length = 256 253 | max_seq_length_eval = 512 254 | per_device_train_batch_size = 4 255 | per_device_eval_batch_size = 4 256 | do_train = True 257 | do_eval = False 258 | do_predict = True 259 | episodes = 1 260 | 261 | [ade_re2] 262 | datasets = ade_re2 263 | model_name_or_path = t5-base 264 | num_train_epochs = 100 265 | max_seq_length = 256 266 | max_seq_length_eval = 512 267 | per_device_train_batch_size = 4 268 | per_device_eval_batch_size = 4 269 | do_train = True 270 | do_eval = False 271 | do_predict = True 272 | episodes = 1 273 | 274 | [ade_re3] 275 | datasets = ade_re3 276 | model_name_or_path = t5-base 277 | num_train_epochs = 100 278 | max_seq_length = 256 279 | max_seq_length_eval = 512 280 | per_device_train_batch_size = 4 281 | per_device_eval_batch_size = 4 282 | do_train = True 283 | do_eval = False 284 | do_predict = True 285 | episodes = 1 286 | 287 | [ade_re4] 288 | datasets = ade_re4 289 | model_name_or_path = t5-base 290 | num_train_epochs = 100 291 | max_seq_length = 256 292 | max_seq_length_eval = 512 293 | per_device_train_batch_size = 4 294 | per_device_eval_batch_size = 4 295 | do_train = True 296 | do_eval = False 297 | do_predict = True 298 | episodes = 1 299 | 300 | [ade_re5] 301 | datasets = ade_re5 302 | model_name_or_path = t5-base 303 | num_train_epochs = 100 304 | max_seq_length = 256 305 | max_seq_length_eval = 512 306 | per_device_train_batch_size = 4 307 | per_device_eval_batch_size = 4 308 | do_train = True 309 | do_eval = False 310 | do_predict = True 311 | episodes = 1 312 | 313 | [ade_re6] 314 | datasets = ade_re6 315 | model_name_or_path = t5-base 316 | num_train_epochs = 100 317 | max_seq_length = 256 318 | max_seq_length_eval = 512 319 | per_device_train_batch_size = 4 320 | per_device_eval_batch_size = 4 321 | do_train = True 322 | do_eval = False 323 | do_predict = True 324 | episodes = 1 325 | 326 | [ade_re7] 327 | datasets = ade_re7 328 | model_name_or_path = t5-base 329 | num_train_epochs = 100 330 | max_seq_length = 256 331 | max_seq_length_eval = 512 332 | per_device_train_batch_size = 4 333 | per_device_eval_batch_size = 4 334 | do_train = True 335 | do_eval = False 336 | do_predict = True 337 | episodes = 1 338 | 339 | [ade_re8] 340 | datasets = ade_re8 341 | model_name_or_path = t5-base 342 | num_train_epochs = 100 343 | max_seq_length = 256 344 | max_seq_length_eval = 512 345 | per_device_train_batch_size = 4 346 | per_device_eval_batch_size = 4 347 | do_train = True 348 | do_eval = False 349 | do_predict = True 350 | episodes = 1 351 | 352 | [ade_re9] 353 | datasets = ade_re9 354 | model_name_or_path = t5-base 355 | num_train_epochs = 100 356 | max_seq_length = 256 357 | max_seq_length_eval = 512 358 | per_device_train_batch_size = 4 359 | per_device_eval_batch_size = 4 360 | do_train = True 361 | do_eval = False 362 | do_predict = True 363 | episodes = 1 364 | 365 | [nyt] 366 | datasets = nyt 367 | model_name_or_path = t5-base 368 | num_train_epochs = 10 369 | max_seq_length = 256 370 | max_seq_length_eval = 512 371 | per_device_train_batch_size = 4 372 | 
per_device_eval_batch_size = 4 373 | do_train = True 374 | do_eval = True 375 | do_predict = True 376 | 377 | [nyt_re] 378 | datasets = nyt_re 379 | model_name_or_path = t5-base 380 | num_train_epochs = 10 381 | max_seq_length = 256 382 | max_seq_length_eval = 512 383 | per_device_train_batch_size = 4 384 | per_device_eval_batch_size = 4 385 | do_train = True 386 | do_eval = True 387 | do_predict = True 388 | 389 | [ace2005_joint_er] 390 | datasets = ace2005_joint_er 391 | model_name_or_path = t5-base 392 | num_train_epochs = 10 393 | max_seq_length = 256 394 | max_seq_length_eval = 512 395 | per_device_train_batch_size = 4 396 | per_device_eval_batch_size = 4 397 | do_train = True 398 | do_eval = True 399 | do_predict = True 400 | 401 | [ace2005_joint_er_re] 402 | datasets = ace2005_joint_er_re 403 | model_name_or_path = t5-base 404 | num_train_epochs = 10 405 | max_seq_length = 256 406 | max_seq_length_eval = 512 407 | per_device_train_batch_size = 4 408 | per_device_eval_batch_size = 4 409 | do_train = True 410 | do_eval = True 411 | do_predict = True 412 | 413 | [ace2005_ner] 414 | datasets = ace2005_ner 415 | model_name_or_path = t5-base 416 | num_train_epochs = 50 417 | max_seq_length = 256 418 | max_seq_length_eval = 512 419 | per_device_train_batch_size = 4 420 | per_device_eval_batch_size = 4 421 | do_train = True 422 | do_eval = True 423 | do_predict = True 424 | 425 | [conll03] 426 | datasets = conll03 427 | model_name_or_path = t5-base 428 | num_train_epochs = 10 429 | max_seq_length = 256 430 | max_seq_length_eval = 512 431 | per_device_train_batch_size = 4 432 | per_device_eval_batch_size = 4 433 | do_train = True 434 | do_eval = False 435 | do_predict = True 436 | 437 | [ontonotes] 438 | datasets = ontonotes 439 | model_name_or_path = t5-base 440 | num_train_epochs = 10 441 | max_seq_length = 256 442 | max_seq_length_eval = 256 443 | per_device_train_batch_size = 4 444 | per_device_eval_batch_size = 4 445 | do_train = True 446 | do_eval = True 447 | do_predict = True 448 | 449 | [genia] 450 | datasets = genia 451 | model_name_or_path = t5-base 452 | num_train_epochs = 10 453 | max_seq_length = 256 454 | max_seq_length_eval = 512 455 | per_device_train_batch_size = 4 456 | per_device_eval_batch_size = 4 457 | do_train = True 458 | do_eval = True 459 | do_predict = True 460 | 461 | [multi_dataset_ner] 462 | datasets = conll03,ontonotes,genia,ace2005_ner 463 | model_name_or_path = t5-base 464 | num_train_epochs = 10 465 | max_seq_length = 256 466 | max_seq_length_eval = 512 467 | per_device_train_batch_size = 4 468 | per_device_eval_batch_size = 4 469 | do_train = True 470 | do_eval = True 471 | do_predict = True 472 | multitask = True 473 | 474 | [fewrel_1shot_5way] 475 | datasets = FewRelEpisodic 476 | model_name_or_path = t5-base 477 | tokenizer_name = t5-base 478 | num_train_epochs = 500 479 | max_seq_length = 256 480 | per_device_train_batch_size = 4 481 | do_train = True 482 | do_eval = False 483 | do_predict = True 484 | episodes = 1-10 485 | num_ways = 5 486 | num_shots = 1 487 | num_query = 5 488 | 489 | [fewrel_5shot_5way] 490 | datasets = FewRelEpisodic 491 | model_name_or_path = t5-base 492 | tokenizer_name = t5-base 493 | num_train_epochs = 500 494 | max_seq_length = 256 495 | per_device_train_batch_size = 4 496 | do_train = True 497 | do_eval = False 498 | do_predict = True 499 | episodes = 1-10 500 | num_ways = 5 501 | num_shots = 5 502 | 503 | 504 | [fewrel_1shot_10way] 505 | datasets = FewRelEpisodic 506 | model_name_or_path = t5-base 507 | tokenizer_name = 
t5-base 508 | num_train_epochs = 500 509 | max_seq_length = 256 510 | per_device_train_batch_size = 4 511 | do_train = True 512 | do_eval = False 513 | do_predict = True 514 | episodes = 1-10 515 | num_ways = 10 516 | num_shots = 1 517 | 518 | [fewrel_5shot_10way] 519 | datasets = FewRelEpisodic 520 | model_name_or_path = t5-base 521 | tokenizer_name = t5-base 522 | num_train_epochs = 500 523 | max_seq_length = 256 524 | per_device_train_batch_size = 4 525 | do_train = True 526 | do_eval = False 527 | do_predict = True 528 | episodes = 1-10 529 | num_ways = 10 530 | num_shots = 5 531 | 532 | [tacred] 533 | datasets = tacred 534 | multitask = True 535 | model_name_or_path = t5-base 536 | num_train_epochs = 10 537 | max_seq_length = 256 538 | train_split = train 539 | per_device_train_batch_size = 4 540 | do_train = True 541 | do_eval = True 542 | do_predict = True 543 | 544 | [conll05_srl] 545 | datasets = conll05_srl 546 | model_name_or_path = t5-base 547 | num_train_epochs = 1 548 | max_seq_length = 256 549 | per_device_train_batch_size = 4 550 | do_train = True 551 | do_eval = True 552 | do_predict = True 553 | 554 | [conll05_srl_brown] 555 | datasets = conll05_srl_brown 556 | model_name_or_path = t5-base 557 | num_train_epochs = 1 558 | max_seq_length = 256 559 | per_device_train_batch_size = 4 560 | do_train = True 561 | do_eval = False 562 | do_predict = True 563 | 564 | [conll05_srl_wsj] 565 | datasets = conll05_srl_wsj 566 | model_name_or_path = t5-base 567 | num_train_epochs = 1 568 | max_seq_length = 256 569 | per_device_train_batch_size = 4 570 | do_train = True 571 | do_eval = False 572 | do_predict = True 573 | 574 | [conll12_srl] 575 | datasets = conll12_srl 576 | model_name_or_path = t5-base 577 | num_train_epochs = 1 578 | max_seq_length = 256 579 | per_device_train_batch_size = 4 580 | do_train = True 581 | do_eval = True 582 | do_predict = True 583 | 584 | [ace2005event_argument] 585 | datasets = ace2005event_argument 586 | model_name_or_path = t5-base 587 | num_train_epochs = 1 588 | max_seq_length = 256 589 | per_device_train_batch_size = 4 590 | do_train = True 591 | do_eval = True 592 | do_predict = True 593 | 594 | [ace2005event_trigger] 595 | datasets = ace2005event_trigger 596 | model_name_or_path = t5-base 597 | num_train_epochs = 1 598 | max_seq_length = 256 599 | per_device_train_batch_size = 4 600 | do_train = True 601 | do_eval = True 602 | do_predict = True 603 | 604 | [ace2005_event] 605 | multitask = True 606 | datasets = ace2005event_argument,ace2005event_trigger 607 | eval_datasets = ace2005event 608 | model_name_or_path = t5-base 609 | num_train_epochs = 10 610 | max_seq_length = 256 611 | max_seq_length_eval = 512 612 | per_device_train_batch_size = 4 613 | per_device_eval_batch_size = 2 614 | do_train = True 615 | do_eval = True 616 | do_predict = True 617 | 618 | [conll12_coref] 619 | datasets = conll12_coref 620 | model_name_or_path = t5-base 621 | num_train_epochs = 1 622 | max_seq_length = 256 623 | max_seq_length_eval = 256 624 | per_device_train_batch_size = 4 625 | per_device_eval_batch_size = 4 626 | do_train = True 627 | do_eval = True 628 | do_predict = True 629 | chunk_size = 128 630 | chunk_overlap = 64 631 | chunk_size_eval = 128 632 | chunk_overlap_eval = 64 633 | 634 | [multi_woz] 635 | datasets = multi_woz 636 | model_name_or_path = t5-base 637 | do_train = True 638 | do_predict = True 639 | num_train_epochs = 10 640 | max_seq_length = 512 641 | per_device_train_batch_size = 3 642 | per_device_eval_batch_size = 2 643 | overwrite_cache = 
True 644 | 645 | [snips] 646 | datasets = snips 647 | model_name_or_path = t5-base 648 | do_train = True 649 | do_predict = True 650 | do_eval = True 651 | num_train_epochs = 20 652 | max_seq_length = 256 653 | max_seq_length_eval = 512 654 | per_device_train_batch_size = 4 655 | per_device_eval_batch_size = 4 656 | 657 | [atis] 658 | datasets = atis 659 | model_name_or_path = t5-base 660 | do_train = True 661 | do_predict = True 662 | do_eval = True 663 | num_train_epochs = 20 664 | max_seq_length = 256 665 | max_seq_length_eval = 512 666 | per_device_train_batch_size = 4 667 | per_device_eval_batch_size = 4 668 | 669 | [googlere] 670 | datasets = googlere 671 | model_name_or_path = t5-base 672 | do_train = True 673 | do_predict = True 674 | do_eval = True 675 | num_train_epochs = 20 676 | max_seq_length = 256 677 | max_seq_length_eval = 512 678 | per_device_train_batch_size = 4 679 | per_device_eval_batch_size = 4 680 | 681 | [trex] 682 | datasets = trex 683 | model_name_or_path = t5-base 684 | do_train = True 685 | do_predict = True 686 | do_eval = True 687 | num_train_epochs = 20 688 | max_seq_length = 256 689 | max_seq_length_eval = 512 690 | per_device_train_batch_size = 4 691 | per_device_eval_batch_size = 4 -------------------------------------------------------------------------------- /src/dataset_processing/coreference_metrics.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/amazon-science/tanl 2 | from typing import List, Tuple, Dict 3 | import numpy as np 4 | from collections import Counter 5 | from scipy.optimize import linear_sum_assignment 6 | 7 | 8 | MUC = 'muc' 9 | BCUBED = 'b_cubed' 10 | CEAFE = 'ceafe' 11 | 12 | 13 | class CorefAllMetrics(object): 14 | """ 15 | Wrapper for coreference resolution metrics. 16 | """ 17 | 18 | @staticmethod 19 | def _get_mention_to_x(clusters: List[list]) -> dict: 20 | mention_to_x = {} 21 | for cluster in clusters: 22 | for m in cluster: 23 | mention_to_x[m] = tuple(cluster) 24 | return mention_to_x 25 | 26 | def _compute_coref_metrics(self, gold_clusters: List[list], predicted_clusters: List[list]) \ 27 | -> Dict[str, Dict[str, float]]: 28 | """ 29 | Compute all coreference metrics given a list of gold clusters and a list of predicted clusters. 30 | """ 31 | mention_to_predicted = self._get_mention_to_x(predicted_clusters) 32 | mention_to_gold = self._get_mention_to_x(gold_clusters) 33 | result = {} 34 | metric_name_evals = [('muc', Evaluator(muc)), ('b_cubed', Evaluator(b_cubed)), ('ceaf', Evaluator(ceafe))] 35 | 36 | for name, evaluator in metric_name_evals: 37 | evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold) 38 | result[name] = { 39 | 'precision': evaluator.get_precision(), 40 | 'recall': evaluator.get_recall(), 41 | 'f1': evaluator.get_f1() 42 | } 43 | 44 | result['average'] = { 45 | 'precision': sum([result[k]['precision'] for k, _ in metric_name_evals]) / len(metric_name_evals), 46 | 'recall': sum([result[k]['recall'] for k, _ in metric_name_evals]) / len(metric_name_evals), 47 | 'f1': sum([result[k]['f1'] for k, _ in metric_name_evals]) / len(metric_name_evals) 48 | } 49 | 50 | return result 51 | 52 | @staticmethod 53 | def _average_nested_dict(list_nested_dict: List[Dict[str, Dict[str, float]]]) -> Dict[str, Dict[str, float]]: 54 | """ 55 | Given a list of 2-level nested dicts, compute the average.
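        For example, averaging [{'muc': {'f1': 0.4}}, {'muc': {'f1': 0.6}}] gives {'muc': {'f1': 0.5}}.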
56 | """ 57 | result_dict = {} 58 | 59 | 60 | for outer_dict in list_nested_dict: 61 | for key_outer, value_outer in outer_dict.items(): 62 | if key_outer not in result_dict: 63 | result_dict[key_outer] = {} 64 | for key_inner, value_inner in value_outer.items(): 65 | result_dict[key_outer][key_inner] = result_dict[key_outer].get(key_inner, 0.0) + value_inner 66 | 67 | 68 | for key_outer, value_outer in result_dict.items(): 69 | for key_inner, value_inner in value_outer.items(): 70 | result_dict[key_outer][key_inner] = result_dict[key_outer][key_inner] / len(list_nested_dict) 71 | 72 | return result_dict 73 | 74 | def get_all_metrics(self, labels: List[List[List[Tuple[int, int]]]], preds: List[List[List[Tuple[int, int]]]])\ 75 | -> Dict[str, Dict[str, Dict[str, float]]]: 76 | """ 77 | Compute all metrics for coreference resolution. 78 | 79 | In input are given two list of mention groups, for example: 80 | [ 81 | [ 82 | [ 83 | (5, 7), 84 | (11, 19), 85 | ... 86 | ], 87 | ... 88 | ] 89 | ] 90 | """ 91 | assert len(labels) == len(preds) 92 | result = {} 93 | 94 | 95 | gold_clusters = [ 96 | [(i,) + span for span in cluster] for i, clusters in enumerate(labels) for cluster in clusters 97 | ] 98 | predicted_clusters = [ 99 | [(i,) + span for span in cluster] for i, clusters in enumerate(preds) for cluster in clusters 100 | ] 101 | 102 | result['micro'] = self._compute_coref_metrics(gold_clusters, predicted_clusters) 103 | 104 | 105 | doc_metrics = [] 106 | for gold_clusters, predicted_clusters in zip(labels, preds): 107 | doc_metrics.append(self._compute_coref_metrics( 108 | gold_clusters, predicted_clusters 109 | )) 110 | result['macro'] = self._average_nested_dict(doc_metrics) 111 | 112 | return result 113 | 114 | 115 | def f1(p_num, p_den, r_num, r_den, beta=1): 116 | p = 0 if p_den == 0 else p_num / float(p_den) 117 | r = 0 if r_den == 0 else r_num / float(r_den) 118 | return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r) 119 | 120 | 121 | class CorefEvaluator(object): 122 | def __init__(self): 123 | self.metric_names = [MUC, BCUBED, CEAFE] 124 | self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] 125 | assert len(self.evaluators) == len(self.metric_names) 126 | self.name_to_evaluator = {n: e for n, e in zip(self.metric_names, self.evaluators)} 127 | 128 | def update(self, predicted, gold, mention_to_predicted, mention_to_gold): 129 | for e in self.evaluators: 130 | e.update(predicted, gold, mention_to_predicted, mention_to_gold) 131 | 132 | def get_f1(self): 133 | return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators) 134 | 135 | def get_recall(self): 136 | return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators) 137 | 138 | def get_precision(self): 139 | return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators) 140 | 141 | def get_prf(self): 142 | return self.get_precision(), self.get_recall(), self.get_f1() 143 | 144 | 145 | class Evaluator(object): 146 | def __init__(self, metric, beta=1): 147 | self.p_num = 0 148 | self.p_den = 0 149 | self.r_num = 0 150 | self.r_den = 0 151 | self.metric = metric 152 | self.beta = beta 153 | 154 | def update(self, predicted, gold, mention_to_predicted, mention_to_gold): 155 | if self.metric == ceafe: 156 | pn, pd, rn, rd = self.metric(predicted, gold, mention_to_predicted, mention_to_gold) 157 | else: 158 | pn, pd = self.metric(predicted, mention_to_gold) 159 | rn, rd = self.metric(gold, mention_to_predicted) 160 | 161 | self.p_num += pn 162 | self.p_den += pd 
163 | self.r_num += rn 164 | self.r_den += rd 165 | 166 | def get_f1(self): 167 | return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta) 168 | 169 | def get_recall(self): 170 | return 0 if self.r_num == 0 else self.r_num / float(self.r_den) 171 | 172 | def get_precision(self): 173 | return 0 if self.p_num == 0 else self.p_num / float(self.p_den) 174 | 175 | def get_prf(self): 176 | return self.get_precision(), self.get_recall(), self.get_f1() 177 | 178 | def get_counts(self): 179 | return self.p_num, self.p_den, self.r_num, self.r_den 180 | 181 | 182 | def evaluate_documents(documents, metric, beta=1): 183 | evaluator = Evaluator(metric, beta=beta) 184 | for document in documents: 185 | evaluator.update(*document)  # each document is a (predicted, gold, mention_to_predicted, mention_to_gold) tuple 186 | return evaluator.get_precision(), evaluator.get_recall(), evaluator.get_f1() 187 | 188 | 189 | def b_cubed(clusters, mention_to_gold): 190 | num, dem = 0, 0 191 | 192 | for c in clusters: 193 | gold_counts = Counter() 194 | correct = 0 195 | for m in c: 196 | if m in mention_to_gold: 197 | gold_counts[tuple(mention_to_gold[m])] += 1 198 | for c2, count in gold_counts.items(): 199 | correct += count * count 200 | num += correct / float(len(c)) 201 | dem += len(c) 202 | return num, dem 203 | 204 | 205 | def muc(clusters, mention_to_gold): 206 | tp, p = 0, 0 207 | for c in clusters: 208 | p += len(c) - 1 209 | tp += len(c) 210 | linked = set() 211 | for m in c: 212 | if m in mention_to_gold: 213 | linked.add(mention_to_gold[m]) 214 | else: 215 | tp -= 1 216 | tp -= len(linked) 217 | return tp, p 218 | 219 | 220 | def phi4(matrix1, matrix2): 221 | m_sum1 = np.sum(matrix1, axis=1) 222 | m_sum2 = np.sum(matrix2, axis=0) 223 | return 2 * np.dot(matrix1, matrix2) / (np.outer(m_sum1, np.ones_like( 224 | m_sum2)) + np.outer(np.ones_like(m_sum1), m_sum2)) 225 | 226 | 227 | def ceafe(clusters, gold_clusters, mention_to_predicted, mention_to_gold): 228 | key_list = list(set(mention_to_gold.keys()).union( 229 | set(mention_to_predicted.keys()))) 230 | 231 | key_to_ix = {} 232 | for i, k in enumerate(key_list): 233 | key_to_ix[k] = i 234 | 235 | len_key = len(key_list) 236 | pred_matrix = np.zeros((len(clusters), len_key)) 237 | gold_matrix = np.zeros((len(gold_clusters), len_key)) 238 | fill_cluster_to_matrix(clusters, pred_matrix, key_to_ix) 239 | fill_cluster_to_matrix(gold_clusters, gold_matrix, key_to_ix) 240 | scores = phi4(pred_matrix, gold_matrix.transpose()) 241 | row_ind, col_ind = linear_sum_assignment(-scores) 242 | similarity = scores[row_ind, col_ind].sum() 243 | 244 | return similarity, len(clusters), similarity, len(gold_clusters) 245 | 246 | 247 | def fill_cluster_to_matrix(clusters, matrix, key_to_ix): 248 | for i, c in enumerate(clusters): 249 | for m in c: 250 | matrix[i][key_to_ix[m]] = 1 251 | -------------------------------------------------------------------------------- /src/dataset_processing/evaluate.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/amazon-science/tanl 2 | from typing import List, Dict 3 | import torch 4 | import logging 5 | import numpy as np 6 | from transformers import PreTrainedTokenizer 7 | 8 | from arguments import DataTrainingArguments 9 | from datasets import load_dataset 10 | 11 | 12 | def get_avg_results(results: List[dict]) -> dict: 13 | """ 14 | Compute average results and standard deviation from many episodes.
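    Each averaged metric maps to a (mean, std) tuple; values that cannot be averaged numerically (e.g. strings) are skipped.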
15 | """ 16 | aggregate_results = {'num_episodes': len(results)} 17 | 18 | for key in results[0]: 19 | try: 20 | numbers = np.array([res[key] for res in results]) 21 | aggregate_results[key] = (numbers.mean(), numbers.std()) 22 | 23 | except: 24 | pass 25 | 26 | return aggregate_results 27 | 28 | 29 | TASK_METRIC_MAPPING = { 30 | "ace2005event_argument": [{'key': 'relation_f1_no_type', 'tkey': 'Argument Id F1'}, 31 | {'key': 'relation_f1', 'tkey': 'Argument Cl F1'}] 32 | } 33 | 34 | 35 | def print_results(results: dict): 36 | 37 | header = f"########## ace2005event_argument Evaluation ##########" 38 | print(header) 39 | for metric in TASK_METRIC_MAPPING['ace2005event_argument']: 40 | print(f"{metric['tkey']:20}: {results[metric['key']][0]}") 41 | print("#" * len(header)) 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | def evaluate(model, dataset_name: str, data_args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, split: str, 57 | seed: int, gpu: int, batch_size: int, mode: str = 'default') -> Dict[str, float]: 58 | """ 59 | Evaluate a model on some dataset. 60 | """ 61 | device = torch.device("cuda", gpu) 62 | test_dataset = load_dataset( 63 | dataset_name, data_args, 64 | max_input_length=data_args.max_seq_length_eval, 65 | max_output_length=data_args.max_output_seq_length_eval, 66 | tokenizer=tokenizer, split=split, seed=seed, shuffle=False, is_eval=True, 67 | ) 68 | 69 | return test_dataset.evaluate_dataset( 70 | data_args=data_args, model=model, device=device, batch_size=batch_size, mode=mode, 71 | external=data_args.data_dir + dataset_name + "/" + "test.jsonl.hyps" 72 | ) 73 | -------------------------------------------------------------------------------- /src/dataset_processing/input_example.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/amazon-science/tanl 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Any, Dict, Union 4 | from torch.utils.data.dataset import Dataset 5 | 6 | 7 | @dataclass 8 | class EntityType: 9 | """ 10 | An entity type in a dataset. 11 | """ 12 | short: str = None 13 | natural: str = None 14 | 15 | def __hash__(self): 16 | return hash(self.short) 17 | 18 | 19 | @dataclass 20 | class RelationType: 21 | """ 22 | A relation type in a dataset. 23 | """ 24 | short: str = None 25 | natural: str = None 26 | 27 | def __hash__(self): 28 | return hash(self.short) 29 | 30 | 31 | @dataclass 32 | class Entity: 33 | """ 34 | An entity in a training/test example. 35 | """ 36 | start: int 37 | end: int 38 | type: Optional[EntityType] = None 39 | id: Optional[int] = None 40 | 41 | def to_tuple(self): 42 | return self.type.natural, self.start, self.end 43 | 44 | def __hash__(self): 45 | return hash((self.id, self.start, self.end)) 46 | 47 | 48 | @dataclass 49 | class Relation: 50 | """ 51 | An (asymmetric) relation in a training/test example. 52 | """ 53 | type: RelationType 54 | head: Entity 55 | tail: Entity 56 | 57 | def to_tuple(self): 58 | return self.type.natural, self.head.to_tuple(), self.tail.to_tuple() 59 | 60 | 61 | @dataclass 62 | class Intent: 63 | """ 64 | The intent of an utterance. 65 | """ 66 | short: str = None 67 | natural: str = None 68 | 69 | def __hash__(self): 70 | return hash(self.short) 71 | 72 | 73 | @dataclass 74 | class InputExample: 75 | """ 76 | A single training/test example. 
77 | """ 78 | id: str 79 | tokens: List[str] 80 | dataset: Optional[Dataset] = None 81 | 82 | 83 | entities: List[Entity] = None 84 | relations: List[Relation] = None 85 | intent: Optional[Intent] = None 86 | 87 | 88 | triggers: List[Entity] = None 89 | 90 | 91 | sentence_level_entities: List[Entity] = None 92 | 93 | 94 | document_id: str = None 95 | chunk_id: int = None 96 | offset: int = None 97 | groups: List[List[Entity]] = None 98 | 99 | 100 | belief_state: Union[Dict[str, Any], str] = None 101 | utterance_tokens: str = None 102 | 103 | 104 | @dataclass 105 | class CorefDocument: 106 | """ 107 | A document for the coreference resolution task. 108 | It has several input examples corresponding to chunks of the document. 109 | """ 110 | id: str 111 | tokens: List[str] 112 | chunks: List[InputExample] 113 | chunk_centers: List[int] 114 | groups: List[List[Entity]] 115 | 116 | 117 | @dataclass 118 | class InputFeatures: 119 | """ 120 | A single set of features of data. 121 | Property names are the same names as the corresponding inputs to a model. 122 | """ 123 | input_ids: List[int] 124 | attention_mask: List[int] 125 | label_ids: Optional[List[int]] = None 126 | -------------------------------------------------------------------------------- /src/dataset_processing/input_formats.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/amazon-science/tanl 2 | from abc import ABC, abstractmethod 3 | import copy 4 | 5 | from input_example import InputExample 6 | from utils import augment_sentence, get_span 7 | 8 | INPUT_FORMATS = {} 9 | 10 | 11 | def register_input_format(format_class): 12 | INPUT_FORMATS[format_class.name] = format_class 13 | return format_class 14 | 15 | 16 | class BaseInputFormat(ABC): 17 | name = None 18 | 19 | BEGIN_ENTITY_TOKEN = '[' 20 | END_ENTITY_TOKEN = ']' 21 | SEPARATOR_TOKEN = '|' 22 | RELATION_SEPARATOR_TOKEN = '=' 23 | QUERY_SEPARATOR_TOKEN = ':' 24 | 25 | def format_input(self, example: InputExample, multitask=False, task_descriptor=None): 26 | res = self._format_input(example=example) 27 | if multitask: 28 | name = task_descriptor or example.dataset.task_descriptor or example.dataset.name 29 | res = f'{name} {self.QUERY_SEPARATOR_TOKEN} ' + res 30 | return res 31 | 32 | @abstractmethod 33 | def _format_input(self, example: InputExample, name='') -> str: 34 | raise NotImplementedError 35 | 36 | 37 | @register_input_format 38 | class PlainInputFormat(BaseInputFormat): 39 | """ 40 | This format uses the plain sentence as input. 41 | """ 42 | name = 'plain' 43 | 44 | def _format_input(self, example: InputExample) -> str: 45 | return ' '.join(example.tokens) 46 | 47 | class JointERInputFormat(BaseInputFormat): 48 | name = '' 49 | def _format_input(self, example: InputExample) -> str: 50 | return "{} sentence: {}".format(self.name, " ".join(example.tokens)) 51 | 52 | @register_input_format 53 | class Conll04InputFormat(JointERInputFormat): 54 | name = 'conll04' 55 | 56 | @register_input_format 57 | class ADEInputFormat(JointERInputFormat): 58 | name = 'ade' 59 | 60 | @register_input_format 61 | class NYTInputFormat(JointERInputFormat): 62 | name = 'nyt' 63 | 64 | @register_input_format 65 | class ACE2005REInputFormat(JointERInputFormat): 66 | name = 'ace2005_joint_er' 67 | 68 | @register_input_format 69 | class RelationClassificationInputFormat(BaseInputFormat): 70 | """ 71 | Input format for relation classification. 
72 | """ 73 | name = 'rc_input' 74 | 75 | ENTITY_SQUARE_BRACKET_LEFT = '' 76 | ENTITY_SQUARE_BRACKET_RIGHT = '' 77 | 78 | def _format_input(self, example: InputExample) -> str: 79 | return "{} Sentence : {}".format(self.name, " ".join(example.tokens)) 80 | 81 | def rc_format_input(self, example: InputExample, name) -> str: 82 | en1_span = [example.entities[0].start, example.entities[0].end] 83 | en2_span = [example.entities[1].start, example.entities[1].end] 84 | words = example.tokens 85 | first, latter, head_first = (en1_span, en2_span, True) if en1_span[0] < en2_span[0] \ 86 | else (en2_span, en1_span, False) 87 | 88 | s = "rc fewrel sentence : " + " ".join(example.tokens) 89 | s += f" The relationship between {get_span(words, en1_span)} and {get_span(words, en2_span)} is" 90 | 91 | return s.strip() 92 | 93 | 94 | @register_input_format 95 | class EventInputFormat(BaseInputFormat): 96 | """ 97 | Input format for event extraction, where an input example contains exactly one trigger. 98 | """ 99 | name = 'ace2005_event_with_trigger' 100 | 101 | def _format_input(self, example: InputExample) -> str: 102 | triggers = example.triggers 103 | assert len(triggers) <= 1 104 | augmentations = [([(entity.type.natural,)], entity.start, entity.end) for entity in triggers] 105 | 106 | return augment_sentence(example.tokens, augmentations, self.BEGIN_ENTITY_TOKEN, self.SEPARATOR_TOKEN, 107 | self.RELATION_SEPARATOR_TOKEN, self.END_ENTITY_TOKEN) 108 | 109 | 110 | @register_input_format 111 | class SRLInput(BaseInputFormat): 112 | """ 113 | Input format for SRL, where the predicate is marked. 114 | """ 115 | name = 'srl_input' 116 | 117 | def _format_input(self, example) -> str: 118 | try: 119 | assert len(example.sentence_level_entities) == 1 120 | start, end = example.sentence_level_entities[0].start, example.sentence_level_entities[0].end 121 | words = copy.copy(example.tokens) 122 | words.insert(end, self.END_ENTITY_TOKEN) 123 | words.insert(start, self.BEGIN_ENTITY_TOKEN) 124 | return ' '.join(words) 125 | except: 126 | return "" 127 | -------------------------------------------------------------------------------- /src/dataset_processing/preprocess_multiwoz/extract_examples.py: -------------------------------------------------------------------------------- 1 | import json 2 | import unicodedata 3 | import string 4 | import re 5 | import random 6 | import time 7 | import math 8 | import ast 9 | from collections import Counter 10 | from collections import OrderedDict 11 | from tqdm import tqdm 12 | import os 13 | import pickle 14 | from random import shuffle 15 | 16 | from fix_label import * 17 | 18 | 19 | EXPERIMENT_DOMAINS = ["hotel", "train", "restaurant", "attraction", "taxi"] 20 | 21 | SLOT_TO_NATURAL = { 22 | "leaveat" : "leave at", 23 | "pricerange" : "price range", 24 | "arriveby" : "arrive by" 25 | } 26 | 27 | 28 | def get_slot_information(ontology): 29 | ontology_domains = dict([(k, v) for k, v in ontology.items() if k.split("-")[0] in EXPERIMENT_DOMAINS]) 30 | SLOTS = [k.replace(" ","").lower() if ("book" not in k) else k.lower() for k in ontology_domains.keys()] 31 | for i in range(len(SLOTS)): 32 | domain, slot = SLOTS[i].split("-") 33 | if slot in SLOT_TO_NATURAL: 34 | slot = SLOT_TO_NATURAL[slot] 35 | SLOTS[i] = domain + "-" + slot 36 | return SLOTS 37 | 38 | 39 | def read_file(file_name, gating_dict, SLOTS, dataset, lang, mem_lang, sequicity, training, max_line = None, args = {"except_domain" : "", "only_domain" : ""}): 40 | """ 41 | Reads examples from train / dev / test files 42 | 
43 | Acknowledgement: most of this code is taken from the trade-dst repo (https://github.com/jasonwu0731/trade-dst) 44 | implementation of the function read_langs 45 | """ 46 | print(("Reading from {}".format(file_name))) 47 | data = [] 48 | max_resp_len, max_value_len = 0, 0 49 | domain_counter = {} 50 | with open(file_name) as f: 51 | dials = json.load(f) 52 | cnt_lin = 1 53 | for dial_dict in dials: 54 | dialog_history = "" 55 | last_belief_dict = {} 56 | 57 | for domain in dial_dict["domains"]: 58 | if domain not in EXPERIMENT_DOMAINS: 59 | continue 60 | if domain not in domain_counter.keys(): 61 | domain_counter[domain] = 0 62 | domain_counter[domain] += 1 63 | 64 | 65 | if args["only_domain"] != "" and args["only_domain"] not in dial_dict["domains"]: 66 | continue 67 | if (args["except_domain"] != "" and dataset == "test" and args["except_domain"] not in dial_dict["domains"]) or \ 68 | (args["except_domain"] != "" and dataset != "test" and [args["except_domain"]] == dial_dict["domains"]): 69 | continue 70 | 71 | 72 | for ti, turn in enumerate(dial_dict["dialogue"]): 73 | turn_domain = turn["domain"] 74 | turn_id = turn["turn_idx"] 75 | agent_utt = "" 76 | user_utt = "[User]: {0}".format(turn["transcript"]) 77 | if len(turn["system_transcript"]): 78 | agent_utt = "[Agent]: {0}".format(turn["system_transcript"]) 79 | turn_uttr = agent_utt+" ; "+user_utt 80 | else: 81 | turn_uttr = user_utt 82 | 83 | turn_uttr_strip = turn_uttr.strip() 84 | 85 | dialog_history += (turn_uttr_strip + " ; ") 86 | source_text = dialog_history.strip() 87 | turn_belief_dict = fix_general_label_error(turn["belief_state"], False, SLOTS) 88 | 89 | 90 | slot_temp = SLOTS 91 | if dataset == "train" or dataset == "dev": 92 | if args["except_domain"] != "": 93 | slot_temp = [k for k in SLOTS if args["except_domain"] not in k] 94 | turn_belief_dict = OrderedDict([(k, v) for k, v in turn_belief_dict.items() if args["except_domain"] not in k]) 95 | elif args["only_domain"] != "": 96 | slot_temp = [k for k in SLOTS if args["only_domain"] in k] 97 | turn_belief_dict = OrderedDict([(k, v) for k, v in turn_belief_dict.items() if args["only_domain"] in k]) 98 | else: 99 | if args["except_domain"] != "": 100 | slot_temp = [k for k in SLOTS if args["except_domain"] in k] 101 | turn_belief_dict = OrderedDict([(k, v) for k, v in turn_belief_dict.items() if args["except_domain"] in k]) 102 | elif args["only_domain"] != "": 103 | slot_temp = [k for k in SLOTS if args["only_domain"] in k] 104 | turn_belief_dict = OrderedDict([(k, v) for k, v in turn_belief_dict.items() if args["only_domain"] in k]) 105 | 106 | turn_belief_list = [str(k)+'-'+str(v) for k, v in turn_belief_dict.items()] 107 | for i in range(len(turn_belief_list)): 108 | domain, label, value = turn_belief_list[i].split("-") 109 | if label in SLOT_TO_NATURAL: 110 | label = SLOT_TO_NATURAL[label] 111 | turn_belief_list[i] = domain + "-" + label + "-" + value 112 | 113 | class_label, generate_y, slot_mask, gating_label = [], [], [], [] 114 | start_ptr_label, end_ptr_label = [], [] 115 | for slot in slot_temp: 116 | if slot in turn_belief_dict.keys(): 117 | generate_y.append(turn_belief_dict[slot]) 118 | 119 | if turn_belief_dict[slot] == "dontcare": 120 | gating_label.append(gating_dict["dontcare"]) 121 | elif turn_belief_dict[slot] == "none": 122 | gating_label.append(gating_dict["none"]) 123 | else: 124 | gating_label.append(gating_dict["ptr"]) 125 | 126 | if max_value_len < len(turn_belief_dict[slot]): 127 | max_value_len = len(turn_belief_dict[slot]) 128 | 129 | 
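                    # slots absent from this turn's belief state are filled with the value "none"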
else: 130 | generate_y.append("none") 131 | gating_label.append(gating_dict["none"]) 132 | 133 | data_detail = { 134 | "ID":dial_dict["dialogue_idx"], 135 | "domains":dial_dict["domains"], 136 | "turn_domain":turn_domain, 137 | "turn_id":turn_id, 138 | "dialog_history":source_text, 139 | "turn_belief":turn_belief_list, 140 | "turn_uttr":turn_uttr_strip 141 | } 142 | data.append(data_detail) 143 | 144 | if max_resp_len < len(source_text.split()): 145 | max_resp_len = len(source_text.split()) 146 | 147 | cnt_lin += 1 148 | if max_line and cnt_lin >= max_line: 149 | break 150 | 151 | print("domain_counter", domain_counter) 152 | return data, max_resp_len, slot_temp 153 | 154 | 155 | 156 | def extract_dataset_instances(data_dir, task="dst"): 157 | """ 158 | Returns a dictionary of train/dev/test instances for the given task 159 | """ 160 | save_dir = os.path.join(data_dir, "splits") 161 | if not os.path.exists(save_dir): 162 | os.mkdir(save_dir) 163 | 164 | file_train = os.path.join(data_dir, 'train_dials.json') 165 | file_dev = os.path.join(data_dir, 'dev_dials.json') 166 | file_test = os.path.join(data_dir, 'test_dials.json') 167 | 168 | ontology = json.load(open(os.path.join(data_dir, "multi-woz/MULTIWOZ2.1/ontology.json"), 'r')) 169 | ALL_SLOTS = get_slot_information(ontology) 170 | 171 | gating_dict = {"ptr":0, "dontcare": "dontcare", "none": "none"} 172 | args = {"except_domain" : "", "only_domain" : ""} 173 | pair_train, train_max_len, slot_train = read_file(file_train, gating_dict, ALL_SLOTS, "train", None, None, None, None, args=args) 174 | pair_dev, dev_max_len, slot_dev = read_file(file_dev, gating_dict, ALL_SLOTS, "dev", None, None, None, None, args=args) 175 | pair_test, test_max_len, slot_test = read_file(file_test, gating_dict, ALL_SLOTS, "test", None, None, None, None, args=args) 176 | 177 | print("Read %s pairs train" % len(pair_train)) 178 | print("Read %s pairs dev" % len(pair_dev)) 179 | print("Read %s pairs test" % len(pair_test)) 180 | 181 | SLOTS_LIST = [ALL_SLOTS, slot_train, slot_dev, slot_test] 182 | print("[Train Set & Dev Set Slots]: Number is {} in total".format(str(len(SLOTS_LIST[2])))) 183 | print(SLOTS_LIST[2]) 184 | print("[Test Set Slots]: Number is {} in total".format(str(len(SLOTS_LIST[3])))) 185 | print(SLOTS_LIST[3]) 186 | 187 | train_dict = {"split" : "train", "examples" : pair_train, "max_len" : train_max_len, "slots" : slot_train} 188 | dev_dict = {"split" : "dev", "examples" : pair_dev, "max_len" : dev_max_len, "slots" : slot_dev} 189 | test_dict = {"split" : "test", "examples" : pair_test, "max_len" : test_max_len, "slots" : slot_test} 190 | 191 | json.dump(train_dict, open(os.path.join(save_dir, "train.json"), "w")) 192 | json.dump(dev_dict, open(os.path.join(save_dir, "dev.json"), "w")) 193 | json.dump(test_dict, open(os.path.join(save_dir, "test.json"), "w")) 194 | 195 | 196 | if __name__ == "__main__": 197 | import sys; extract_dataset_instances(sys.argv[1])  # the data directory must be passed on the command line 198 | -------------------------------------------------------------------------------- /src/dataset_processing/preprocess_multiwoz/prepare_multi_woz.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import shutil 5 | 6 | from extract_examples import extract_dataset_instances 7 | 8 | 9 | parser = argparse.ArgumentParser(description='Process Multi-WOZ 2.1 Dataset.') 10 | parser.add_argument('--data-dir', type=str, required=True, help='where to save the pre-processed multi-woz data') 11 | 12 | 13 | 14 | def
remove_domains_from_dataset(input_path, output_path, rm_domains): 15 | data = json.load(open(input_path, "r")) 16 | examples = [] 17 | for x in data["examples"]: 18 | # skip any turn whose domain is one of rm_domains 19 | if any([d in x["turn_domain"] for d in rm_domains]): 20 | continue 21 | examples.append(x) 22 | data["examples"] = examples 23 | json.dump(data, open(output_path, "w")) 24 | 25 | 26 | def main(args): 27 | 28 | extract_dataset_instances(args.data_dir) 29 | split_dir = os.path.join(args.data_dir, "splits") 30 | 31 | split_names = ["train", "dev", "test"] 32 | split_paths = [os.path.join(split_dir, "{0}.json".format(split)) for split in split_names] 33 | 34 | 35 | for path, split in zip(split_paths, split_names): 36 | dir_path = os.path.dirname(path) 37 | save_path = os.path.join(dir_path, "multi_woz_2.1_{0}_5_domain.json".format(split)) 38 | print("saving split to:", save_path) 39 | remove_domains_from_dataset( 40 | input_path=path, 41 | output_path=save_path, 42 | rm_domains=["police", "hospital"] 43 | ) 44 | print("removing: ", path) 45 | os.remove(path) 46 | 47 | 48 | if __name__ == "__main__": 49 | args = parser.parse_args() 50 | main(args) 51 | 52 | 53 | -------------------------------------------------------------------------------- /src/dataset_processing/run.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/amazon-science/tanl 2 | import argparse 3 | import configparser 4 | import itertools 5 | import json 6 | import logging 7 | import os 8 | from collections import defaultdict 9 | import torch 10 | from torch.utils.data import DataLoader 11 | from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, AutoModelForSeq2SeqLM, Trainer 12 | 13 | from arguments import ModelArguments, DataTrainingArguments, TrainingArguments 14 | from datasets import load_dataset 15 | from evaluate import evaluate, get_avg_results, print_results 16 | from utils import get_episode_indices 17 | 18 | 19 | def main(): 20 | assert torch.cuda.is_available(), 'CUDA not available' 21 | 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('job') 25 | parser.add_argument('-c', '--config_file', type=str, default='config.ini', help='configuration file') 26 | parser.add_argument('-e', '--eval', action='store_true', default=False, help='run evaluation only') 27 | parser.add_argument('--evaluate_checkpoints', action='store_true', default=False, 28 | help='evaluate intermediate checkpoints instead of the final model') 29 | parser.add_argument('--evaluate_last_checkpoint', action='store_true', default=False, 30 | help='evaluate the last intermediate checkpoint instead of the final model') 31 | parser.add_argument('--evaluate_checkpoint_in_dir', type=str, default=None, 32 | help='evaluate the checkpoint in the given directory') 33 | parser.add_argument('-a', '--evaluate_all', action='store_true', default=False, 34 | help='evaluate intermediate checkpoints together with the final model') 35 | parser.add_argument('-g', '--gpu', type=int, default=0, help='which GPU to use for evaluation') 36 | parser.add_argument('-v', '--verbose_results', action='store_true', default=False, 37 | help='print results for each evaluation run') 38 | parser.add_argument('-mode', dest='mode', type=str, default='default', help='mode') 39 | parser.add_argument('--data_only', dest='data_only', action='store_true', default=False, 40 | help='Data generation only') 41 | parser.add_argument('--evaluate_only', dest='evaluate_only', action='store_true', default=False, 42 |
help='Evaluation only') 43 | parser.add_argument('--debug', dest='debug', action='store_true', default=False, 44 | help='Debug mode (small dev, test set)') 45 | args, remaining_args = parser.parse_known_args() 46 | 47 | 48 | config = configparser.ConfigParser(allow_no_value=False) 49 | config.read(args.config_file) 50 | job = args.job 51 | assert job in config 52 | 53 | 54 | defaults = { 55 | 'overwrite_output_dir': True, 56 | 'overwrite_cache': True, 57 | 'per_device_eval_batch_size': 4, 58 | 'learning_rate': 5e-4, 59 | 'logging_steps': 5000, 60 | 'save_steps': 0, 61 | } 62 | 63 | 64 | defaults.update(dict(config.items(job))) 65 | for key in defaults: 66 | if defaults[key] in ['True', 'False']: 67 | 68 | defaults[key] = config.getboolean(job, key) 69 | if defaults[key] == 'None': 70 | 71 | defaults[key] = None 72 | 73 | if args.eval: 74 | 75 | defaults['do_train'] = False 76 | 77 | 78 | second_parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 79 | second_parser.set_defaults(**defaults) 80 | 81 | model_args, data_args, training_args = second_parser.parse_args_into_dataclasses(remaining_args) 82 | training_args.do_train = training_args.do_train and not args.evaluate_only 83 | training_args.do_eval = training_args.do_eval and not args.evaluate_only 84 | 85 | try: 86 | os.mkdir(training_args.output_dir) 87 | except FileExistsError: 88 | pass 89 | 90 | 91 | if data_args.max_output_seq_length_eval is None: 92 | 93 | data_args.max_output_seq_length_eval = data_args.max_output_seq_length \ 94 | or data_args.max_seq_length_eval \ 95 | or data_args.max_seq_length 96 | 97 | if data_args.max_output_seq_length is None: 98 | 99 | data_args.max_output_seq_length = data_args.max_seq_length 100 | 101 | if data_args.max_seq_length_eval is None: 102 | 103 | data_args.max_seq_length_eval = data_args.max_seq_length 104 | 105 | if data_args.chunk_size_eval is None: 106 | 107 | data_args.chunk_size_eval = data_args.chunk_size 108 | 109 | if data_args.chunk_overlap_eval is None: 110 | 111 | data_args.chunk_overlap_eval = data_args.chunk_overlap 112 | 113 | 114 | 115 | output_dir = os.path.join( 116 | training_args.output_dir, 117 | f'{args.job}' 118 | f'-{model_args.model_name_or_path.split("/")[-1]}' 119 | f'-ep{round(training_args.num_train_epochs)}' 120 | f'-len{data_args.max_seq_length}' 121 | ) 122 | 123 | if data_args.max_output_seq_length != data_args.max_seq_length: 124 | output_dir += f'-{data_args.max_output_seq_length}' 125 | 126 | if training_args.learning_rate != 5e-4: 127 | output_dir += f'-lr{training_args.learning_rate}' 128 | 129 | output_dir += f'-b{training_args.per_device_train_batch_size}' \ 130 | f'-{data_args.train_split}' 131 | 132 | if data_args.chunk_size != 128: 133 | output_dir += f'-chunk{data_args.chunk_size}' 134 | if data_args.chunk_overlap != 64: 135 | output_dir += f'-overlap{data_args.chunk_overlap}' 136 | 137 | if data_args.output_format is not None: 138 | output_dir += f'-{data_args.output_format}' 139 | if data_args.input_format is not None: 140 | output_dir += f'-{data_args.input_format}' 141 | if data_args.train_subset < 1: 142 | output_dir += f'-size{data_args.train_subset:.2f}' 143 | 144 | try: 145 | os.mkdir(output_dir) 146 | except FileExistsError: 147 | pass 148 | 149 | 150 | logging.basicConfig( 151 | filename=os.path.join(output_dir, 'logs.log'), 152 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 153 | datefmt='%Y-%m-%d %H:%M:%S', 154 | level=logging.INFO, 155 | ) 156 | 
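    # mirror log records to the console in addition to logs.log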
logging.getLogger().addHandler(logging.StreamHandler()) 157 | 158 | 159 | evaluation_output_filename = f'results' 160 | if data_args.num_beams is not None: 161 | evaluation_output_filename += f'-{data_args.num_beams}beams' 162 | if data_args.max_seq_length_eval is not None: 163 | evaluation_output_filename += f'-len{data_args.max_seq_length_eval}' 164 | 165 | 166 | config = AutoConfig.from_pretrained( 167 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 168 | cache_dir=model_args.cache_dir, 169 | ) 170 | 171 | 172 | tokenizer = AutoTokenizer.from_pretrained( 173 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 174 | ) 175 | 176 | 177 | dataset_names = data_args.datasets.split(',') 178 | 179 | 180 | if args.job.startswith('ade'): 181 | episode_indices = [int(args.job[-1])] 182 | else: 183 | episode_indices = [-1] 184 | 185 | 186 | 187 | evaluation_results = defaultdict(list) 188 | for ep_idx in episode_indices: 189 | print() 190 | logging.info(f'Episode {ep_idx} ({len(episode_indices)} episodes total)') 191 | episode_output_dir = os.path.join(output_dir, f'episode{ep_idx}') 192 | 193 | try: 194 | os.mkdir(episode_output_dir) 195 | except FileExistsError: 196 | pass 197 | 198 | logging.info(f'Output directory: {episode_output_dir}') 199 | 200 | training_args.output_dir = episode_output_dir 201 | 202 | 203 | model = None 204 | 205 | 206 | if training_args.do_train: 207 | 208 | datasets = [] 209 | print(dataset_names) 210 | for dataset_name in dataset_names: 211 | logging.info(f'Process dataset {dataset_name} (train)') 212 | dataset = load_dataset( 213 | dataset_name, data_args, split=data_args.train_split, 214 | max_input_length=data_args.max_seq_length, max_output_length=data_args.max_output_seq_length, 215 | tokenizer=tokenizer, seed=ep_idx, train_subset=data_args.train_subset, 216 | ) 217 | 218 | 219 | dataset.preprocess_for_glm(mode=args.mode,dataset=dataset_name,debug=args.debug) 220 | 221 | datasets.append(dataset) 222 | 223 | 224 | if args.data_only: 225 | exit(0) 226 | 227 | 228 | if training_args.local_rank in [-1, 0] and (training_args.do_eval or training_args.do_predict): 229 | 230 | evaluation_splits = [] 231 | if training_args.do_eval: 232 | evaluation_splits.append('dev') 233 | if training_args.do_predict: 234 | evaluation_splits.append('test') 235 | 236 | 237 | evaluation_dirs = [] 238 | 239 | if args.evaluate_checkpoints or args.evaluate_last_checkpoint or \ 240 | args.evaluate_checkpoint_in_dir or args.evaluate_all: 241 | 242 | evaluation_dirs = list(sorted([ 243 | checkpoint_dir 244 | for checkpoint_dir in os.listdir(episode_output_dir) 245 | if checkpoint_dir.startswith('checkpoint-') 246 | ], key=lambda x: int(x[len('checkpoint-'):]))) 247 | if args.evaluate_last_checkpoint: 248 | 249 | evaluation_dirs = [evaluation_dirs[-1]] 250 | elif args.evaluate_checkpoint_in_dir: 251 | assert args.evaluate_checkpoint_in_dir in evaluation_dirs, \ 252 | "checkpoint {} does not exist".format(args.evaluate_checkpoint_in_dir) 253 | evaluation_dirs = [args.evaluate_checkpoint_in_dir] 254 | 255 | if args.evaluate_all or (not args.evaluate_checkpoints and not args.evaluate_last_checkpoint): 256 | 257 | evaluation_dirs += [''] 258 | 259 | 260 | if data_args.eval_datasets is None: 261 | eval_dataset_names = dataset_names 262 | else: 263 | eval_dataset_names = data_args.eval_datasets.split(',') 264 | 265 | 266 | for comb in itertools.product(evaluation_splits, evaluation_dirs, eval_dataset_names): 267 | split, 
evaluation_dir, dataset_name = comb
268 |                 model_dir = os.path.join(episode_output_dir, evaluation_dir)
269 | 
270 |                 if len(evaluation_dir) > 0:
271 |                     logging.info(f'Evaluate {evaluation_dir} on {dataset_name} {split}')
272 |                 else:
273 |                     logging.info(f'Evaluate on {dataset_name} {split}')
274 | 
275 |                 res = evaluate(
276 |                     model=model, dataset_name=dataset_name, data_args=data_args, tokenizer=tokenizer, split=split,
277 |                     seed=ep_idx, batch_size=training_args.per_device_eval_batch_size, gpu=args.gpu, mode=args.mode
278 |                 )
279 | 
280 |                 evaluation_results[comb].append(res)
281 | 
282 | 
283 |                 if args.verbose_results:
284 |                     print_results(res)
285 | 
286 | 
287 |                 with open(
288 |                     os.path.join(model_dir, evaluation_output_filename + f'-{dataset_name}-{split}.json'), 'w'
289 |                 ) as f:
290 |                     json.dump(res, f, indent=0)
291 | 
292 | 
293 |     for comb, results in evaluation_results.items():
294 |         split, evaluation_dir, dataset_name = comb
295 | 
296 |         print()
297 |         logging.info(
298 |             f'Average of {split} results over {len(results)} episodes ({dataset_name} {evaluation_dir}):'
299 |         )
300 |         res = get_avg_results(results)
301 | 
302 | 
303 |         print_results(res)
304 | 
305 | 
306 |         filename = evaluation_output_filename + f'-{dataset_name}-{split}'
307 |         if len(evaluation_dir) > 0:
308 |             filename += '-'
309 |         filename += f'{evaluation_dir}.json'
310 | 
311 |         with open(os.path.join(output_dir, filename), 'w') as f:
312 |             json.dump(res, f, indent=0)
313 | 
314 |     print()
315 |     logging.info(f'Model weights and intermediate checkpoints saved in {output_dir}')
316 | 
317 | 
318 | if __name__ == "__main__":
319 |     main()
320 | 
-------------------------------------------------------------------------------- /src/dataset_processing/utils.py: --------------------------------------------------------------------------------
1 | # Adapted from https://github.com/amazon-science/tanl
2 | import logging
3 | from typing import Tuple, List, Dict
4 | 
5 | 
6 | def get_episode_indices(episodes_string: str) -> List[int]:
7 |     """
8 |     Parse a string such as '2' or '1-5' into a list of integers such as [2] or [1, 2, 3, 4, 5].
9 |     """
10 |     episode_indices = []
11 | 
12 |     if episodes_string is not None and episodes_string != '':
13 |         ll = [int(item) for item in episodes_string.split('-')]
14 | 
15 |         if len(ll) == 1:
16 |             episode_indices = ll
17 | 
18 |         else:
19 |             _start, _end = ll
20 |             episode_indices = list(range(_start, _end + 1))
21 | 
22 |     return episode_indices
23 | 
24 | 
25 | def expand_tokens(tokens: List[str], augmentations: List[Tuple[List[tuple], int, int]],
26 |                   entity_tree: Dict[int, List[int]], root: int,
27 |                   begin_entity_token: str, sep_token: str, relation_sep_token: str, end_entity_token: str) \
28 |         -> List[str]:
29 |     """
30 |     Recursively expand the tokens to obtain a sentence in augmented natural language.
31 | 
32 |     Used in the augment_sentence function below (see the documentation there).
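
    Editor's note (added for exposition, not in the original docstring): the
    recursion walks entity_tree depth-first. For each annotated span it copies
    the plain tokens before the span, wraps the span in begin_entity_token /
    end_entity_token, and appends the span's tags separated by sep_token, with
    relation_sep_token between a relation label and its tail argument; see the
    worked example in augment_sentence below.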
33 | """ 34 | new_tokens = [] 35 | root_start, root_end = augmentations[root][1:] if root >= 0 else (0, len(tokens)) 36 | i = root_start 37 | 38 | for entity_index in entity_tree[root]: 39 | tags, start, end = augmentations[entity_index] 40 | 41 | 42 | new_tokens += tokens[i:start] 43 | 44 | 45 | new_tokens.append(begin_entity_token) 46 | new_tokens += expand_tokens(tokens, augmentations, entity_tree, entity_index, 47 | begin_entity_token, sep_token, relation_sep_token, end_entity_token) 48 | 49 | for tag in tags: 50 | if tag[0]: 51 | 52 | new_tokens.append(sep_token) 53 | new_tokens.append(tag[0]) 54 | 55 | for x in tag[1:]: 56 | new_tokens.append(relation_sep_token) 57 | new_tokens.append(x) 58 | 59 | new_tokens.append(end_entity_token) 60 | i = end 61 | 62 | 63 | new_tokens += tokens[i:root_end] 64 | 65 | return new_tokens 66 | 67 | 68 | def augment_sentence(tokens: List[str], augmentations: List[Tuple[List[tuple], int, int]], begin_entity_token: str, 69 | sep_token: str, relation_sep_token: str, end_entity_token: str) -> str: 70 | """ 71 | Augment a sentence by adding tags in the specified positions. 72 | 73 | Args: 74 | tokens: Tokens of the sentence to augment. 75 | augmentations: List of tuples (tags, start, end). 76 | begin_entity_token: Beginning token for an entity, e.g. '[' 77 | sep_token: Separator token, e.g. '|' 78 | relation_sep_token: Separator token for relations, e.g. '=' 79 | end_entity_token: End token for an entity e.g. ']' 80 | 81 | An example follows. 82 | 83 | tokens: 84 | ['Tolkien', 'was', 'born', 'here'] 85 | 86 | augmentations: 87 | [ 88 | ([('person',), ('born in', 'here')], 0, 1), 89 | ([('location',)], 3, 4), 90 | ] 91 | 92 | output augmented sentence: 93 | [ Tolkien | person | born in = here ] was born [ here | location ] 94 | """ 95 | 96 | augmentations = list(sorted(augmentations, key=lambda z: (z[1], -z[2]))) 97 | 98 | 99 | 100 | root = -1 101 | entity_tree = {root: []} 102 | current_stack = [root] 103 | 104 | for j, x in enumerate(augmentations): 105 | tags, start, end = x 106 | if any(augmentations[k][1] < start < augmentations[k][2] < end for k in current_stack): 107 | 108 | logging.warning(f'Tree structure is not satisfied! Dropping annotation {x}') 109 | continue 110 | 111 | while current_stack[-1] >= 0 and \ 112 | not (augmentations[current_stack[-1]][1] <= start <= end <= augmentations[current_stack[-1]][2]): 113 | current_stack.pop() 114 | 115 | 116 | entity_tree[current_stack[-1]].append(j) 117 | 118 | 119 | current_stack.append(j) 120 | 121 | 122 | entity_tree[j] = [] 123 | 124 | return ' '.join(expand_tokens( 125 | tokens, augmentations, entity_tree, root, begin_entity_token, sep_token, relation_sep_token, end_entity_token 126 | )) 127 | 128 | 129 | def get_span(l: List[str], span: List[int]): 130 | assert len(span) == 2 131 | return " ".join([l[i] for i in range(span[0], span[1]) if i < len(l)]) 132 | 133 | 134 | def get_precision_recall_f1(num_correct, num_predicted, num_gt): 135 | assert 0 <= num_correct <= num_predicted 136 | assert 0 <= num_correct <= num_gt 137 | 138 | precision = num_correct / num_predicted if num_predicted > 0 else 0. 139 | recall = num_correct / num_gt if num_gt > 0 else 0. 140 | f1 = 2. / (1. / precision + 1. / recall) if num_correct > 0 else 0. 
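    # Editor's note (comment added for exposition, not in the original file):
    # the num_correct > 0 guard avoids a division by zero when precision or
    # recall is 0. With hypothetical counts num_correct=8, num_predicted=10,
    # num_gt=16:
    #     precision = 8/10 = 0.8, recall = 8/16 = 0.5
    #     f1 = 2 / (1/0.8 + 1/0.5) = 2 / 3.25 ~= 0.615
    # i.e. the harmonic mean, equivalent to 2*P*R / (P + R).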
141 | 142 | return precision, recall, f1 143 | -------------------------------------------------------------------------------- /src/download_ckpt.sh: -------------------------------------------------------------------------------- 1 | cd ../ckpt 2 | 3 | # download multi-task checkpoints 4 | mkdir MP 5 | cd MP 6 | 7 | echo -e 'Downloading 10B multitask checkpoint. This takes a long time, and you can skip it if you only want to have a try on DeepStruct.' 8 | mkdir 10B 9 | cd ./10B 10 | wget https://huggingface.co/Magolor/deepstruct/resolve/main/hub/MP/10B/mp_rank_00_model_states.pt 11 | cd .. 12 | 13 | mkdir 10B_1 14 | cd ./10B_1 15 | wget https://huggingface.co/Magolor/deepstruct/resolve/main/hub/MP/10B_1/mp_rank_00_model_states.pt 16 | cd .. 17 | -------------------------------------------------------------------------------- /src/glm/config_tasks/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 4, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 50, 5 | "gradient_clipping": 1.0, 6 | "zero_optimization": { 7 | "stage": 2, 8 | "contiguous_gradients": false, 9 | "overlap_comm": true, 10 | "reduce_scatter": true, 11 | "reduce_bucket_size": 50000000.0, 12 | "allgather_bucket_size": 50000000.0, 13 | "cpu_offload": true 14 | }, 15 | "zero_allow_untested_optimizer": true, 16 | "fp16": { 17 | "enabled": true, 18 | "loss_scale": 0, 19 | "loss_scale_window": 1000, 20 | "hysteresis": 2, 21 | "min_loss_scale": 1 22 | }, 23 | "optimizer": { 24 | "type": "Adam", 25 | "params": { 26 | "lr": 5e-06, 27 | "betas": [ 28 | 0.9, 29 | 0.95 30 | ], 31 | "eps": 1e-08, 32 | "weight_decay": 0.01 33 | } 34 | }, 35 | "activation_checkpointing": { 36 | "partition_activations": false, 37 | "contiguous_memory_optimization": false 38 | }, 39 | "wall_clock_breakdown": false 40 | } 41 | -------------------------------------------------------------------------------- /src/glm/config_tasks/config_mutliserver.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 4, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 50, 5 | "gradient_clipping": 1.0, 6 | "zero_optimization": { 7 | "stage": 2, 8 | "contiguous_gradients": false, 9 | "overlap_comm": true, 10 | "reduce_scatter": true, 11 | "reduce_bucket_size": 50000000.0, 12 | "allgather_bucket_size": 50000000.0, 13 | "cpu_offload": true 14 | }, 15 | "zero_allow_untested_optimizer": true, 16 | "fp16": { 17 | "enabled": true, 18 | "loss_scale": 0, 19 | "loss_scale_window": 1000, 20 | "hysteresis": 2, 21 | "min_loss_scale": 1 22 | }, 23 | "optimizer": { 24 | "type": "Adam", 25 | "params": { 26 | "lr": 5e-06, 27 | "betas": [ 28 | 0.9, 29 | 0.95 30 | ], 31 | "eps": 1e-08, 32 | "weight_decay": 0.01 33 | } 34 | }, 35 | "activation_checkpointing": { 36 | "partition_activations": false, 37 | "contiguous_memory_optimization": false 38 | }, 39 | "wall_clock_breakdown": false 40 | } -------------------------------------------------------------------------------- /src/glm/config_tasks/model_blocklm_10B_pretrain.sh: -------------------------------------------------------------------------------- 1 | 2 | MODEL_TYPE=blocklm-10B 3 | MODEL_ARGS="--block-lm \ 4 | --cloze-eval \ 5 | --task-mask \ 6 | --num-layers 48 \ 7 | --hidden-size 4096 \ 8 | --num-attention-heads 64 \ 9 | --max-position-embeddings 1024 \ 10 | --tokenizer-model-type gpt2 \ 11 | --tokenizer-type GPT2BPETokenizer \ 12 | --load-pretrained 
../../ckpt/PRETRAIN" 13 | -------------------------------------------------------------------------------- /src/glm/config_tasks/pretrain.sh: -------------------------------------------------------------------------------- 1 | 2 | EXPERIMENT_NAME=${MODEL_TYPE}-cnndm_org 3 | TASK_NAME=cnn_dm_original 4 | DATA_PATH="${DATA_ROOT}/${TASK_DATASET}" 5 | 6 | TRAIN_ARGS="--epochs 3 \ 7 | --batch-size 4 \ 8 | --lr 1e-5 \ 9 | --lr-decay-style linear \ 10 | --warmup 0.06 \ 11 | --weight-decay 1.0e-1 \ 12 | --label-smoothing 0.1" 13 | COMMON_ARGS="--save-interval 10000 \ 14 | --log-interval 10 \ 15 | --eval-interval 10000 \ 16 | --eval-iters 10000 \ 17 | --eval-epoch 5" 18 | TASK_ARGS="--src-seq-length 512 \ 19 | --tgt-seq-length 512 \ 20 | --min-tgt-length 0 \ 21 | --length-penalty 0.3 \ 22 | --no-repeat-ngram-size 0 \ 23 | --num-beams 8 \ 24 | --select-topk \ 25 | --eval-batch-size 1" 26 | -------------------------------------------------------------------------------- /src/glm/data_utils/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | 2 | """Tokenization classes for OpenAI GPT.""" 3 | from __future__ import (absolute_import, division, print_function, 4 | unicode_literals) 5 | 6 | import sys 7 | import json 8 | import logging 9 | import os 10 | import regex as re 11 | from io import open 12 | 13 | try: 14 | from functools import lru_cache 15 | except ImportError: 16 | 17 | 18 | def lru_cache(): 19 | return lambda func: func 20 | 21 | from .file_utils import cached_path 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 26 | 'gpt2': ".pytorch_pretrained_bert/gpt2-vocab.json", 27 | "roberta": ".pytorch_pretrained_bert/roberta-vocab.json" 28 | } 29 | PRETRAINED_MERGES_ARCHIVE_MAP = { 30 | 'gpt2': ".pytorch_pretrained_bert/gpt2-merges.txt", 31 | "roberta": ".pytorch_pretrained_bert/roberta-merges.txt" 32 | } 33 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 34 | 'gpt2': 1024, 35 | } 36 | VOCAB_NAME = 'vocab.json' 37 | MERGES_NAME = 'merges.txt' 38 | SPECIAL_TOKENS_NAME = 'special_tokens.txt' 39 | 40 | @lru_cache() 41 | def bytes_to_unicode(): 42 | """ 43 | Returns list of utf-8 byte and a corresponding list of unicode strings. 44 | The reversible bpe codes work on unicode strings. 45 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 46 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 47 | This is a signficant percentage of your normal, say, 32K bpe vocab. 48 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 49 | And avoids mapping to whitespace/control characters the bpe code barfs on. 50 | """ 51 | _chr = unichr if sys.version_info[0] == 2 else chr 52 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 53 | cs = bs[:] 54 | n = 0 55 | for b in range(2**8): 56 | if b not in bs: 57 | bs.append(b) 58 | cs.append(2**8+n) 59 | n += 1 60 | cs = [_chr(n) for n in cs] 61 | return dict(zip(bs, cs)) 62 | 63 | def get_pairs(word): 64 | """Return set of symbol pairs in a word. 65 | 66 | Word is represented as tuple of symbols (symbols being variable-length strings). 67 | """ 68 | pairs = set() 69 | prev_char = word[0] 70 | for char in word[1:]: 71 | pairs.add((prev_char, char)) 72 | prev_char = char 73 | return pairs 74 | 75 | class GPT2Tokenizer(object): 76 | """ 77 | GPT-2 BPE tokenizer. 
Peculiarities: 78 | - Byte-level BPE 79 | """ 80 | @classmethod 81 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 82 | """ 83 | Instantiate a PreTrainedBertModel from a pre-trained model file. 84 | Download and cache the pre-trained model file if needed. 85 | """ 86 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 87 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 88 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 89 | special_tokens_file = None 90 | else: 91 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 92 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 93 | special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) 94 | if not os.path.exists(special_tokens_file): 95 | special_tokens_file = None 96 | else: 97 | logger.info("loading special tokens file {}".format(special_tokens_file)) 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | resolved_vocab_file = vocab_file 121 | resolved_merges_file = merges_file 122 | logger.info("loading vocabulary file {}".format(vocab_file)) 123 | logger.info("loading merges file {}".format(merges_file)) 124 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 125 | 126 | 127 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 128 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 129 | 130 | if special_tokens_file and 'special_tokens' not in kwargs: 131 | special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] 132 | else: 133 | special_tokens = kwargs.pop('special_tokens', []) 134 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) 135 | return tokenizer 136 | 137 | def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): 138 | self.max_len = max_len if max_len is not None else int(1e12) 139 | self.encoder = json.load(open(vocab_file, encoding='utf-8')) 140 | self.decoder = {v:k for k,v in self.encoder.items()} 141 | self.errors = errors 142 | self.byte_encoder = bytes_to_unicode() 143 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 144 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 145 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 146 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 147 | self.cache = {} 148 | 149 | 150 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 151 | 152 | self.special_tokens = {} 153 | self.special_tokens_decoder = {} 154 | self.set_special_tokens(special_tokens) 155 | 156 | def __len__(self): 157 | return len(self.encoder) + len(self.special_tokens) 158 | 159 | def set_special_tokens(self, special_tokens): 160 | """ Add a list of additional tokens to the encoder. 161 | The additional tokens are indexed starting from the last index of the 162 | current vocabulary in the order of the `special_tokens` list. 
163 |         """
164 |         if not special_tokens:
165 |             self.special_tokens = {}
166 |             self.special_tokens_decoder = {}
167 |             return
168 |         self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
169 |         self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
170 |         logger.info("Special tokens {}".format(self.special_tokens))
171 | 
172 |     def bpe(self, token):
173 |         if token in self.cache:
174 |             return self.cache[token]
175 |         word = tuple(token)
176 |         pairs = get_pairs(word)
177 | 
178 |         if not pairs:
179 |             return token
180 | 
181 |         while True:
182 |             bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
183 |             if bigram not in self.bpe_ranks:
184 |                 break
185 |             first, second = bigram
186 |             new_word = []
187 |             i = 0
188 |             while i < len(word):
189 |                 try:
190 |                     j = word.index(first, i)
191 |                     new_word.extend(word[i:j])
192 |                     i = j
193 |                 except ValueError:
194 |                     new_word.extend(word[i:])
195 |                     break
196 | 
197 |                 if word[i] == first and i < len(word)-1 and word[i+1] == second:
198 |                     new_word.append(first+second)
199 |                     i += 2
200 |                 else:
201 |                     new_word.append(word[i])
202 |                     i += 1
203 |             new_word = tuple(new_word)
204 |             word = new_word
205 |             if len(word) == 1:
206 |                 break
207 |             else:
208 |                 pairs = get_pairs(word)
209 |         word = ' '.join(word)
210 |         self.cache[token] = word
211 |         return word
212 | 
213 |     def tokenize(self, text):
214 |         """ Tokenize a string. """
215 |         bpe_tokens = []
216 |         for token in re.findall(self.pat, text):
217 |             if sys.version_info[0] == 2:
218 |                 token = ''.join(self.byte_encoder[ord(b)] for b in token)
219 |             else:
220 |                 token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
221 |             bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
222 |         return bpe_tokens
223 | 
224 |     def convert_tokens_to_ids(self, tokens):
225 |         """ Converts a sequence of tokens into ids using the vocab. """
226 |         ids = []
227 |         if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
228 |             if tokens in self.special_tokens:
229 |                 return self.special_tokens[tokens]
230 |             else:
231 |                 return self.encoder.get(tokens, 0)
232 |         for token in tokens:
233 |             if token in self.special_tokens:
234 |                 ids.append(self.special_tokens[token])
235 |             else:
236 |                 ids.append(self.encoder.get(token, 0))
237 |         if len(ids) > self.max_len:
238 |             logger.warning(
239 |                 "Token indices sequence length is longer than the specified maximum "
240 |                 " sequence length for this OpenAI GPT model ({} > {}). 
Running this" 241 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 242 | ) 243 | return ids 244 | 245 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 246 | """Converts a sequence of ids in BPE tokens using the vocab.""" 247 | tokens = [] 248 | for i in ids: 249 | if i in self.special_tokens_decoder: 250 | if not skip_special_tokens: 251 | tokens.append(self.special_tokens_decoder[i]) 252 | else: 253 | tokens.append(self.decoder[i]) 254 | return tokens 255 | 256 | def encode(self, text): 257 | return self.convert_tokens_to_ids(self.tokenize(text)) 258 | 259 | def decode(self, tokens): 260 | text = ''.join([self.decoder[token] for token in tokens]) 261 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 262 | return text 263 | 264 | def save_vocabulary(self, vocab_path): 265 | """Save the tokenizer vocabulary and merge files to a directory.""" 266 | if not os.path.isdir(vocab_path): 267 | logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) 268 | return 269 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 270 | merge_file = os.path.join(vocab_path, MERGES_NAME) 271 | special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) 272 | 273 | with open(vocab_file, 'w', encoding='utf-8') as f: 274 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 275 | 276 | index = 0 277 | with open(merge_file, "w", encoding="utf-8") as writer: 278 | writer.write(u'#version: 0.2\n') 279 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 280 | if index != token_index: 281 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 282 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 283 | index = token_index 284 | writer.write(' '.join(bpe_tokens) + u'\n') 285 | index += 1 286 | 287 | index = len(self.encoder) 288 | with open(special_tokens_file, 'w', encoding='utf-8') as writer: 289 | for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): 290 | if index != token_index: 291 | logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 
292 | " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) 293 | index = token_index 294 | writer.write(token + u'\n') 295 | index += 1 296 | 297 | return vocab_file, merge_file, special_tokens_file 298 | -------------------------------------------------------------------------------- /src/glm/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed import PyTorchDistributedDataParallel, DistributedDataParallel 2 | from .modeling_glm import GLMConfig, GLMModel, glm_get_params_for_weight_decay_optimization 3 | from .downstream import GLMForMultiTokenCloze, GLMForMultiTokenClozeFast, GLMForSingleTokenCloze, \ 4 | GLMForSequenceClassification 5 | -------------------------------------------------------------------------------- /src/glm/scripts/ds_finetune_seq2seq.sh: -------------------------------------------------------------------------------- 1 | source ../PATH.sh 2 | DATA_ROOT=$ROOT/data 3 | CHECKPOINT_PATH=$CHECKPOINTROOT 4 | SAVE_PATH=$ROOT/finetune_checkpoints 5 | DATESTR=$(date +"%m-%d-%H-%M") 6 | 7 | source $1 # Model 8 | 9 | TASK_DATASET=$3 10 | source $2 # Task 11 | 12 | NUM_WORKERS=1 13 | NUM_GPUS_PER_WORKER=8 14 | HOST_FILE_PATH="./hostfile" 15 | MP_SIZE=1 16 | MASTER_PORT=$(shuf -n 1 -i 10000-65535) 17 | 18 | OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" 19 | DISTRIBUTED_ARGS="${OPTIONS_NCCL} deepspeed --hostfile ${HOST_FILE_PATH} --master_port ${MASTER_PORT} --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER}" 20 | 21 | EXPERIMENT_NAME=${EXPERIMENT_NAME}_${DATESTR} 22 | mkdir logs 23 | run_cmd="${DISTRIBUTED_ARGS} finetune_glm.py \ 24 | --deepspeed \ 25 | --deepspeed_config config_tasks/config.json \ 26 | --finetune \ 27 | --experiment-name ${EXPERIMENT_NAME} \ 28 | --task ${TASK_NAME} \ 29 | --task-dataset ${TASK_DATASET} \ 30 | --data-dir ${DATA_PATH} \ 31 | --save ${CHECKPOINT_PATH} \ 32 | --checkpoint-activations \ 33 | --num-workers ${NUM_WORKERS} \ 34 | --no-load-lr-scheduler \ 35 | $MODEL_ARGS \ 36 | $TRAIN_ARGS \ 37 | $COMMON_ARGS \ 38 | $TASK_ARGS \ 39 | --fp16 \ 40 | --model-parallel-size ${MP_SIZE} \ 41 | --overwrite \ 42 | 2>&1 | tee logs/log-${EXPERIMENT_NAME}.txt" 43 | 44 | echo $EXPERIMENT_NAME > "runs/latest_run" 45 | 46 | echo ${run_cmd} 47 | eval ${run_cmd} 48 | 49 | -------------------------------------------------------------------------------- /src/glm/scripts/ds_finetune_seq2seq_multiserver.sh: -------------------------------------------------------------------------------- 1 | source ../PATH.sh 2 | DATA_ROOT=$ROOT/data 3 | CHECKPOINT_PATH=$CHECKPOINTROOT 4 | SAVE_PATH=$ROOT/finetune_checkpoints 5 | DATESTR=$(date +"%m-%d-%H-%M") 6 | 7 | source $1 # Model 8 | 9 | TASK_DATASET=$3 10 | source $2 # Task 11 | 12 | NUM_WORKERS=12 13 | NUM_GPUS_PER_WORKER=8 14 | HOST_FILE_PATH="./hostfile_multiserver" 15 | MP_SIZE=1 16 | MASTER_PORT=$(shuf -n 1 -i 10000-65535) 17 | 18 | OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" 19 | DISTRIBUTED_ARGS="${OPTIONS_NCCL} deepspeed --hostfile ${HOST_FILE_PATH} --master_port ${MASTER_PORT} --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER}" 20 | 21 | EXPERIMENT_NAME=${EXPERIMENT_NAME}_${DATESTR} 22 | mkdir logs 23 | run_cmd="${DISTRIBUTED_ARGS} finetune_glm.py \ 24 | --deepspeed \ 25 | --deepspeed_config config_tasks/config_multiserver.json \ 26 | --finetune \ 27 | --experiment-name ${EXPERIMENT_NAME} \ 28 | --task ${TASK_NAME} \ 29 | --task-dataset ${TASK_DATASET} \ 30 | 
--data-dir ${DATA_PATH} \ 31 | --save ${CHECKPOINT_PATH} \ 32 | --checkpoint-activations \ 33 | --num-workers ${NUM_WORKERS} \ 34 | --no-load-lr-scheduler \ 35 | $MODEL_ARGS \ 36 | $TRAIN_ARGS \ 37 | $COMMON_ARGS \ 38 | $TASK_ARGS \ 39 | --fp16 \ 40 | --model-parallel-size ${MP_SIZE} \ 41 | --overwrite \ 42 | 2>&1 | tee logs/log-${EXPERIMENT_NAME}.txt" 43 | 44 | echo $EXPERIMENT_NAME > "runs/latest_run" 45 | 46 | echo ${run_cmd} 47 | eval ${run_cmd} 48 | 49 | -------------------------------------------------------------------------------- /src/glm/scripts/ds_finetune_seq2seq_pretrain.sh: -------------------------------------------------------------------------------- 1 | source ../PATH.sh 2 | DATA_ROOT=$ROOT/data 3 | CHECKPOINT_PATH=$CHECKPOINTROOT 4 | SAVE_PATH=$ROOT/finetune_checkpoints 5 | DATESTR=$(date +"%m-%d-%H-%M") 6 | 7 | source $1 # Model 8 | 9 | TASK_DATASET=$3 10 | source $2 # Task 11 | 12 | NUM_WORKERS=1 13 | NUM_GPUS_PER_WORKER=1 # Set default GPU to 1 14 | HOST_FILE_PATH="./hostfile" 15 | MP_SIZE=1 16 | MASTER_PORT=$(shuf -n 1 -i 10000-65535) 17 | 18 | OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" 19 | DISTRIBUTED_ARGS="${OPTIONS_NCCL} deepspeed --hostfile ${HOST_FILE_PATH} --master_port ${MASTER_PORT} --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER}" 20 | 21 | EXPERIMENT_NAME=${EXPERIMENT_NAME}_${DATESTR} 22 | mkdir logs 23 | run_cmd="${DISTRIBUTED_ARGS} finetune_glm.py \ 24 | --deepspeed \ 25 | --deepspeed_config config_tasks/config.json \ 26 | --finetune \ 27 | --experiment-name ${EXPERIMENT_NAME} \ 28 | --task ${TASK_NAME} \ 29 | --task-dataset ${TASK_DATASET} \ 30 | --data-dir ${DATA_PATH} \ 31 | --save ${SAVE_PATH} \ 32 | --save-interval 1000 \ 33 | --checkpoint-activations \ 34 | --num-workers ${NUM_WORKERS} \ 35 | --no-load-lr-scheduler \ 36 | $MODEL_ARGS \ 37 | $TRAIN_ARGS \ 38 | $COMMON_ARGS \ 39 | $TASK_ARGS \ 40 | --fp16 \ 41 | --model-parallel-size ${MP_SIZE} \ 42 | --overwrite \ 43 | --epochs 3 \ 44 | 2>&1 | tee logs/log-${EXPERIMENT_NAME}.txt" 45 | 46 | echo $EXPERIMENT_NAME > "runs/latest_run" 47 | 48 | echo ${run_cmd} 49 | eval ${run_cmd} 50 | 51 | -------------------------------------------------------------------------------- /src/glm/tasks/data_utils.py: -------------------------------------------------------------------------------- 1 | 2 | """ Tasks data utility.""" 3 | import copy 4 | import json 5 | import pickle 6 | 7 | import re 8 | from typing import Dict, List, Optional 9 | 10 | import numpy as np 11 | import torch 12 | import torch.utils.data 13 | from torch.utils.data.dataloader import default_collate 14 | 15 | import mpu 16 | 17 | 18 | def clean_text(text): 19 | """Remove new lines and multiple spaces and adjust end of sentence dot.""" 20 | 21 | text = text.replace("\n", " ") 22 | text = re.sub(r'\s+', ' ', text) 23 | for _ in range(3): 24 | text = text.replace(' . ', '. ') 25 | 26 | return text 27 | 28 | 29 | class InputExample(object): 30 | """A raw input example consisting of one or two segments of text and a label""" 31 | 32 | def __init__(self, guid, text_a, text_b=None, label=None, logits=None, meta: Optional[Dict] = None, idx=-1, 33 | num_choices=1): 34 | """ 35 | Create a new InputExample. 
36 | 37 | :param guid: a unique textual identifier 38 | :param text_a: the sequence of text 39 | :param text_b: an optional, second sequence of text 40 | :param label: an optional label 41 | :param logits: an optional list of per-class logits 42 | :param meta: an optional dictionary to store arbitrary meta information 43 | :param idx: an optional numeric index 44 | """ 45 | self.guid = guid 46 | self.text_a = text_a 47 | self.text_b = text_b 48 | self.label = label 49 | self.logits = logits 50 | self.idx = idx 51 | self.num_choices = num_choices 52 | self.meta = meta if meta else {} 53 | 54 | def __repr__(self): 55 | return str(self.to_json_string()) 56 | 57 | def to_dict(self): 58 | """Serialize this instance to a Python dictionary.""" 59 | output = copy.deepcopy(self.__dict__) 60 | return output 61 | 62 | def to_json_string(self): 63 | """Serialize this instance to a JSON string.""" 64 | return json.dumps(self.to_dict(), sort_keys=True) 65 | 66 | @staticmethod 67 | def from_json_string(json_str): 68 | """Deserialize this instance from a JSON string.""" 69 | data = json.loads(json_str) 70 | return InputExample(**data) 71 | 72 | @staticmethod 73 | def load_examples(path: str) -> List['InputExample']: 74 | """Load a set of input examples from a file""" 75 | with open(path, 'rb') as fh: 76 | return pickle.load(fh) 77 | 78 | @staticmethod 79 | def save_examples(examples: List['InputExample'], path: str) -> None: 80 | """Save a set of input examples to a file""" 81 | with open(path, 'wb') as fh: 82 | pickle.dump(examples, fh) 83 | 84 | 85 | def num_special_tokens_to_add(text_a_ids, text_b_ids, answer_ids, add_cls, add_sep, add_piece, add_eos=True): 86 | num_tokens = 0 87 | if add_cls: 88 | num_tokens += 1 89 | if text_b_ids and add_sep: 90 | num_tokens += 1 91 | if add_eos: 92 | num_tokens += 1 93 | if not answer_ids and add_piece: 94 | num_tokens += 1 95 | return num_tokens 96 | 97 | 98 | def build_input_from_ids(text_a_ids, text_b_ids, answer_ids, max_seq_length, tokenizer, args=None, add_cls=True, 99 | add_sep=False, add_piece=False, add_eos=True, mask_id=None): 100 | if mask_id is None: 101 | mask_id = tokenizer.get_command('MASK').Id 102 | eos_id = tokenizer.get_command('eos').Id 103 | cls_id = tokenizer.get_command('ENC').Id 104 | sep_id = tokenizer.get_command('sep').Id 105 | ids = [] 106 | types = [] 107 | paddings = [] 108 | 109 | if add_cls: 110 | ids.append(cls_id) 111 | types.append(0) 112 | paddings.append(1) 113 | 114 | len_text_a = len(text_a_ids) 115 | ids.extend(text_a_ids) 116 | types.extend([0] * len_text_a) 117 | paddings.extend([1] * len_text_a) 118 | 119 | if text_b_ids is not None: 120 | 121 | if add_sep: 122 | ids.append(sep_id) 123 | types.append(0) 124 | paddings.append(1) 125 | len_text_b = len(text_b_ids) 126 | ids.extend(text_b_ids) 127 | types.extend([1] * len_text_b) 128 | paddings.extend([1] * len_text_b) 129 | eos_length = 1 if add_eos else 0 130 | 131 | if len(ids) >= max_seq_length - eos_length: 132 | max_seq_length_m1 = max_seq_length - 1 133 | ids = ids[0:max_seq_length_m1] 134 | types = types[0:max_seq_length_m1] 135 | paddings = paddings[0:max_seq_length_m1] 136 | end_type = 0 if text_b_ids is None else 1 137 | if add_eos: 138 | ids.append(eos_id) 139 | types.append(end_type) 140 | paddings.append(1) 141 | sep = len(ids) 142 | target_ids = [0] * len(ids) 143 | loss_masks = [0] * len(ids) 144 | position_ids = list(range(len(ids))) 145 | block_position_ids = [0] * len(ids) 146 | 147 | if add_piece or answer_ids is not None: 148 | sop_id = 
tokenizer.get_command('sop').Id 149 | mask_position = ids.index(mask_id) if not args.sentinel_token else args.max_position_embeddings 150 | ids.append(sop_id) 151 | types.append(end_type) 152 | paddings.append(1) 153 | position_ids.append(mask_position) 154 | block_position_ids.append(1) 155 | if answer_ids is not None: 156 | len_answer = len(answer_ids) 157 | ids.extend(answer_ids[:-1]) 158 | types.extend([end_type] * (len_answer - 1)) 159 | paddings.extend([1] * (len_answer - 1)) 160 | position_ids.extend([mask_position] * (len_answer - 1)) 161 | if not args.no_block_position: 162 | block_position_ids.extend(range(2, len(answer_ids) + 1)) 163 | else: 164 | block_position_ids.extend([1] * (len(answer_ids) - 1)) 165 | target_ids.extend(answer_ids) 166 | loss_masks.extend([1] * len(answer_ids)) 167 | else: 168 | target_ids.append(0) 169 | loss_masks.append(1) 170 | 171 | padding_length = max_seq_length - len(ids) 172 | if padding_length > 0: 173 | ids.extend([eos_id] * padding_length) 174 | types.extend([eos_id] * padding_length) 175 | paddings.extend([0] * padding_length) 176 | position_ids.extend([0] * padding_length) 177 | block_position_ids.extend([0] * padding_length) 178 | target_ids.extend([0] * padding_length) 179 | loss_masks.extend([0] * padding_length) 180 | if not args.masked_lm: 181 | position_ids = [position_ids, block_position_ids] 182 | return ids, types, paddings, position_ids, sep, target_ids, loss_masks 183 | 184 | 185 | def build_decoder_input(enc_ids, answer_ids, max_seq_length, max_dec_seq_length, tokenizer): 186 | mask_id = tokenizer.get_command('MASK').Id 187 | eos_id = tokenizer.get_command('eos').Id 188 | sop_id = tokenizer.get_command('sop').Id 189 | enc_len = len(enc_ids) 190 | masks = [] 191 | 192 | 193 | 194 | 195 | mask_position = enc_ids.index(mask_id) 196 | len_answer = len(answer_ids) 197 | ids = [sop_id] + answer_ids[:-1] 198 | types = [0] * len_answer 199 | paddings = [1] * len_answer 200 | position_ids = [mask_position] * len_answer 201 | block_position_ids = list(range(1, len_answer + 1)) 202 | target_ids = answer_ids 203 | loss_masks = [1] * len_answer 204 | 205 | padding_length = max_dec_seq_length - len(ids) 206 | if padding_length > 0: 207 | ids.extend([eos_id] * padding_length) 208 | types.extend([0] * padding_length) 209 | paddings.extend([0] * padding_length) 210 | position_ids.extend([0] * padding_length) 211 | block_position_ids.extend([0] * padding_length) 212 | target_ids.extend([0] * padding_length) 213 | loss_masks.extend([0] * padding_length) 214 | position_ids = [position_ids, block_position_ids] 215 | return ids, types, paddings, position_ids, masks, target_ids, loss_masks 216 | 217 | 218 | def build_sample(ids, types=None, paddings=None, positions=None, masks=None, label=None, unique_id=None, target=None, 219 | logit_mask=None, segment_ids=None, prompt_ids=None): 220 | """Convert to numpy and return a sample consumed by the batch producer.""" 221 | 222 | ids_np = np.array(ids, dtype=np.int64) 223 | sample = {'text': ids_np, 'label': int(label)} 224 | if types is not None: 225 | types_np = np.array(types, dtype=np.int64) 226 | sample['types'] = types_np 227 | if paddings is not None: 228 | paddings_np = np.array(paddings, dtype=np.int64) 229 | sample['padding_mask'] = paddings_np 230 | if positions is not None: 231 | positions_np = np.array(positions, dtype=np.int64) 232 | sample['position'] = positions_np 233 | if masks is not None: 234 | masks_np = np.array(masks, dtype=np.int64) 235 | sample['mask'] = masks_np 236 | if target is not 
None: 237 | target_np = np.array(target, dtype=np.int64) 238 | sample['target'] = target_np 239 | if logit_mask is not None: 240 | logit_mask_np = np.array(logit_mask, dtype=np.int64) 241 | sample['logit_mask'] = logit_mask_np 242 | if segment_ids is not None: 243 | segment_ids = np.array(segment_ids, dtype=np.int64) 244 | sample['segment_id'] = segment_ids 245 | if prompt_ids is not None: 246 | prompt_ids = np.array(prompt_ids, dtype=np.int64) 247 | sample['prompt_pos'] = prompt_ids 248 | if unique_id is not None: 249 | sample['uid'] = unique_id 250 | return sample 251 | 252 | 253 | def build_decoder_sample(sample, dec_ids, dec_position, dec_masks, dec_target, dec_logit_mask): 254 | sample['dec_text'] = np.array(dec_ids) 255 | sample['dec_position'] = np.array(dec_position) 256 | sample['dec_mask'] = np.array(dec_masks) 257 | sample['dec_target'] = np.array(dec_target) 258 | sample['dec_logit_mask'] = np.array(dec_logit_mask) 259 | return sample 260 | 261 | 262 | def my_collate(batch): 263 | new_batch = [{key: value for key, value in sample.items() if key != 'uid'} for sample in batch] 264 | text_list = [sample['text'] for sample in batch] 265 | 266 | def pad_choice_dim(data, choice_num): 267 | if len(data) < choice_num: 268 | data = np.concatenate([data] + [data[0:1]] * (choice_num - len(data))) 269 | return data 270 | 271 | if len(text_list[0].shape) == 2: 272 | choice_nums = list(map(len, text_list)) 273 | max_choice_num = max(choice_nums) 274 | for i, sample in enumerate(new_batch): 275 | for key, value in sample.items(): 276 | if key != 'label': 277 | sample[key] = pad_choice_dim(value, max_choice_num) 278 | else: 279 | sample[key] = value 280 | sample['loss_mask'] = np.array([1] * choice_nums[i] + [0] * (max_choice_num - choice_nums[i]), 281 | dtype=np.int64) 282 | 283 | if 'dec_text' in new_batch[0]: 284 | choice_nums = [len(sample['dec_text']) for sample in new_batch] 285 | if choice_nums.count(choice_nums[0]) != len(choice_nums): 286 | max_choice_num = max(choice_nums) 287 | for i, sample in enumerate(new_batch): 288 | for key, value in sample.items(): 289 | if key.startswith('dec_'): 290 | sample[key] = pad_choice_dim(value, max_choice_num) 291 | sample['loss_mask'] = np.array([1] * choice_nums[i] + [0] * (max_choice_num - choice_nums[i]), 292 | dtype=np.int64) 293 | 294 | new_batch = default_collate(new_batch) 295 | if 'uid' in batch[0]: 296 | uid_list = [sample['uid'] for sample in batch] 297 | new_batch['uid'] = uid_list 298 | return new_batch 299 | 300 | 301 | class FakeDataloader: 302 | def __init__(self, num_iters): 303 | self.num_iters = num_iters 304 | 305 | def __iter__(self): 306 | if self.num_iters is not None: 307 | for _ in range(self.num_iters): 308 | yield None 309 | else: 310 | while True: 311 | yield None 312 | 313 | 314 | def build_data_loader(dataset, batch_size, num_workers, drop_last, shuffle=True, only_rank0=False): 315 | """Data loader. 
Note that batch-size is the local (per GPU) batch-size.""" 316 | 317 | 318 | if only_rank0: 319 | rank, world_size = 0, 1 320 | else: 321 | world_size = mpu.get_data_parallel_world_size() 322 | rank = mpu.get_data_parallel_rank() 323 | sampler = torch.utils.data.distributed.DistributedSampler( 324 | dataset, num_replicas=world_size, rank=rank, shuffle=shuffle) 325 | 326 | 327 | data_loader = torch.utils.data.DataLoader(dataset, 328 | batch_size=batch_size, 329 | sampler=sampler, 330 | shuffle=False, 331 | num_workers=num_workers, 332 | drop_last=drop_last, 333 | pin_memory=True, 334 | collate_fn=my_collate) 335 | 336 | return data_loader 337 | -------------------------------------------------------------------------------- /src/glm/tasks/eval_utils.py: -------------------------------------------------------------------------------- 1 | 2 | """Evaluation utilities.""" 3 | 4 | import os 5 | import time 6 | import random 7 | import torch 8 | import datetime 9 | 10 | import mpu 11 | from utils import print_rank_0, get_spare_port, debug_finetune_data 12 | from tasks.data_utils import build_data_loader 13 | from finetune_glm import process_batch 14 | from collections import OrderedDict, defaultdict 15 | from typing import List 16 | from tasks.data_utils import InputExample 17 | from sklearn.metrics import f1_score 18 | import numpy as np 19 | 20 | 21 | def accuracy_metric(predictions, labels, examples): 22 | count = 0 23 | num_predictions = len(predictions) 24 | for prediction, label in zip(predictions, labels): 25 | count += prediction == label 26 | return count * 100.0 / num_predictions if num_predictions>0 else 0 27 | 28 | from fuzzywuzzy import fuzz 29 | from tqdm import tqdm 30 | import argparse 31 | 32 | import re 33 | 34 | from evaluate import * 35 | 36 | def f1_macro_metric(predictions, labels, examples, print_results=True): 37 | return (f1_score(labels, predictions, average='macro') if len(predictions)>0 else 0) if len(labels)>0 else 1 38 | 39 | global_tokenizer = None 40 | 41 | def accuracy_func_provider(single_dataset_provider, metric_dict, args, is_test=False, eval_func=None, output_func=None, 42 | only_rank0=True, tokenizer=None): 43 | """Provide function that calculates accuracies.""" 44 | 45 | global global_tokenizer 46 | global_tokenizer = tokenizer 47 | if only_rank0 and torch.distributed.is_initialized() and torch.distributed.get_rank() != 0: 48 | return None 49 | if is_test and not args.eval_valid: 50 | print("using test set...") 51 | datapaths = args.test_data if args.test_data is not None else ['test'] 52 | else: 53 | print("using test dev...") 54 | datapaths = args.valid_data if args.valid_data is not None else ['dev'] 55 | if eval_func is None: 56 | eval_func = multichoice_evaluate 57 | dataloaders = [] 58 | eval_batch_size = args.eval_batch_size if args.eval_batch_size else args.batch_size 59 | for datapath in datapaths: 60 | dataset = single_dataset_provider(datapath) 61 | dataloader = build_data_loader( 62 | dataset, eval_batch_size, num_workers=args.num_workers, 63 | drop_last=False, shuffle=False, only_rank0=only_rank0) 64 | dataloaders.append((dataset.dataset_name, dataloader)) 65 | 66 | def metrics_func(model, epoch, output_predictions=True, summary_writer=None): 67 | print_rank_0('calculating metrics ...') 68 | score_dict = OrderedDict([(key, 0.0) for key in metric_dict]) if isinstance(metric_dict, dict) else { 69 | metric_dict: 0.0} 70 | total = 0 71 | for name, dataloader in dataloaders: 72 | example_dict = None 73 | if hasattr(dataloader.dataset, "examples"): 74 
| example_dict = dataloader.dataset.examples 75 | start_time = time.time() 76 | predictions, labels, examples = eval_func(model, dataloader, example_dict, args) 77 | elapsed_time = time.time() - start_time 78 | if output_predictions and torch.distributed.get_rank() == 0: 79 | filename = os.path.join(args.log_dir, name + '.jsonl') 80 | output_func(predictions, examples, filename) 81 | total_count = len(predictions) 82 | single_dict = {key: metric(predictions, labels, examples) for key, metric in metric_dict.items()} 83 | output_str = ' > |epoch: {}| metrics for {}: total {}'.format(epoch, name, total_count) 84 | for key, value in single_dict.items(): 85 | output_str += " {} = {:.4f} %".format(key, value) 86 | if summary_writer is not None and epoch >= 0 and not is_test and len(dataloaders) > 1: 87 | summary_writer.add_scalar(f'Train/valid_{name}_{key}', value, epoch) 88 | output_str += ' elapsed time (sec): {:.3f}'.format(elapsed_time) 89 | if len(dataloaders) > 1: 90 | print_rank_0(output_str) 91 | for key in score_dict: 92 | score_dict[key] += single_dict[key] * total_count 93 | total += total_count 94 | score_dict = {key: score / float(total) if total>0 else 0 for key, score in score_dict.items()} 95 | output_str = ' >> |epoch: {}| overall: total = {}'.format(epoch, total) 96 | for key, score in score_dict.items(): 97 | output_str += " {} = {:.4f}".format(key, score) 98 | if summary_writer is not None and epoch >= 0 and not is_test: 99 | summary_writer.add_scalar(f'Train/valid_{key}', score, epoch) 100 | print_rank_0(output_str) 101 | return score_dict 102 | 103 | return metrics_func 104 | 105 | 106 | segment_length = 10 107 | 108 | 109 | def multichoice_evaluate(model, dataloader, example_dict, args): 110 | """Calculate correct over total answers and return prediction if the 111 | `output_predictions` is true.""" 112 | model.eval() 113 | port = get_spare_port(args) 114 | print_rank_0(f"Using port {port}") 115 | store = torch.distributed.TCPStore(args.master_ip, port, 116 | torch.distributed.get_world_size(), 117 | torch.distributed.get_rank() == 0, datetime.timedelta(seconds=30)) 118 | 119 | 120 | 121 | with torch.no_grad(): 122 | 123 | for _, batch in enumerate(dataloader): 124 | 125 | data = process_batch(batch, args) 126 | if args.pretrained_bert: 127 | tokens, types, labels_, attention_mask = data['text'], data['types'], data['label'], data[ 128 | 'padding_mask'] 129 | inputs = [tokens, types, attention_mask] 130 | elif args.cloze_eval: 131 | tokens, labels_, position_ids = data['text'], data['label'], data['position'] 132 | attention_mask, target_ids, logit_mask = data['mask'], data['target'], data['logit_mask'] 133 | if not args.fast_decode: 134 | inputs = [tokens, position_ids, attention_mask, target_ids, logit_mask] 135 | if args.continuous_prompt: 136 | prompt_pos = data["prompt_pos"] 137 | inputs.append(prompt_pos) 138 | else: 139 | dec_input_ids, dec_position_ids, dec_attention_mask = data['dec_text'], data['dec_position'], data[ 140 | 'dec_mask'] 141 | dec_target_ids, dec_logit_mask = data['dec_target'], data['dec_logit_mask'] 142 | inputs = [tokens, position_ids, attention_mask, dec_input_ids, dec_position_ids, dec_attention_mask, 143 | dec_target_ids, dec_logit_mask] 144 | else: 145 | tokens, labels_, position_ids, attention_mask = data['text'], data['label'], data['position'], data[ 146 | 'mask'] 147 | inputs = [tokens, position_ids, attention_mask] 148 | if len(inputs[0].shape) == 3 and inputs[0].size(1) > segment_length: 149 | logit_list = [] 150 | for i in 
range((inputs[0].size(1) - 1) // segment_length + 1): 151 | input_batch = [arg[:, i * segment_length: (i + 1) * segment_length] for arg in inputs] 152 | if args.pretrained_bert: 153 | logits = model(*input_batch) 154 | else: 155 | logits, *mems = model(*input_batch) 156 | logit_list.append(logits) 157 | logits = torch.cat(logit_list, dim=1) 158 | elif args.cloze_eval and args.fast_decode: 159 | logit_list = [] 160 | num_choices = inputs[3].size(1) 161 | for i in range((num_choices - 1) // segment_length + 1): 162 | input_batch = inputs[:3] + [arg[:, i * segment_length: (i + 1) * segment_length] for arg in 163 | inputs[3:]] 164 | logits, *mems = model(*input_batch) 165 | logit_list.append(logits) 166 | logits = torch.cat(logit_list, dim=1) 167 | else: 168 | if args.pretrained_bert: 169 | logits = model(*inputs) 170 | else: 171 | logits, *mems = model(*inputs) 172 | if "segment_id" in data: 173 | from torch_scatter import scatter_sum 174 | if "loss_mask" in data: 175 | logits = logits * data["loss_mask"] 176 | logits = scatter_sum(logits, data["segment_id"], dim=1) 177 | elif "loss_mask" in data: 178 | loss_mask = data["loss_mask"] 179 | logits = logits * loss_mask - 10000.0 * (1.0 - loss_mask) 180 | uid_list = batch['uid'] 181 | if isinstance(uid_list, torch.Tensor): 182 | uid_list = uid_list.cpu().numpy().tolist() 183 | predicted = torch.argmax(logits, dim=-1).tolist() 184 | labels = labels_.tolist() 185 | if args.task.lower() == 'wsc': 186 | predicted = [1 if pred == 0 else 0 for pred in predicted] 187 | if mpu.get_model_parallel_rank() == 0: 188 | for uid, prediction, label in zip(uid_list, predicted, labels): 189 | store.set(uid, str((prediction, label))) 190 | model.train() 191 | torch.distributed.barrier() 192 | predictions, labels, examples = [], [], [] 193 | for uid, example in example_dict.items(): 194 | prediction, label = eval(store.get(uid)) 195 | predictions.append(prediction) 196 | labels.append(label) 197 | examples.append(example) 198 | torch.distributed.barrier() 199 | return predictions, labels, examples 200 | -------------------------------------------------------------------------------- /src/glm/tasks/seq2seq/finetune.py: -------------------------------------------------------------------------------- 1 | 2 | """Race.""" 3 | import torch 4 | import mpu 5 | import json 6 | import functools 7 | from tasks.eval_utils import accuracy_func_provider, f1_metric, f1_macro_metric 8 | from finetune_glm import finetune 9 | from pretrain_glm import get_batch 10 | from collections import OrderedDict 11 | from tasks.seq2seq.dataset import Seq2SeqDataset, BlankLMDataset, ExtractionDataset 12 | from tasks.seq2seq.evaluate import rouge_metric, DecoderEvaluater, BlankLMEvaluater 13 | from tasks.superglue.evaluate import squad_exact_match, squad_f1 14 | 15 | global_tokenizer = None 16 | 17 | 18 | def seq2seq_forward_step(data, model, args, timers, mems): 19 | """Forward step.""" 20 | 21 | 22 | if timers is not None: 23 | timers('batch generator').start() 24 | tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data, args) 25 | if timers is not None: 26 | timers('batch generator').stop() 27 | 28 | logits, *mems = model(tokens, position_ids, attention_mask, *mems) 29 | 30 | 31 | losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels) 32 | if args.label_smoothing > 0.0: 33 | epsilon = args.label_smoothing 34 | smooth_loss = -torch.nn.functional.log_softmax(logits, dim=-1).mean(dim=-1) 35 | losses = (1 - epsilon) * losses + epsilon * smooth_loss 36 | 
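    # Editor's note (comment added for exposition, not in the original file):
    # with label smoothing, the per-token loss computed above is
    #     (1 - epsilon) * CE(logits, labels) + epsilon * mean(-log_softmax(logits))
    # and the two lines below reduce it to a masked average,
    #     loss = sum(losses * loss_mask) / sum(loss_mask)
    # so padded positions contribute nothing to the gradient.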
loss_mask = loss_mask.reshape(-1) 37 | 38 | loss = torch.sum(losses.reshape(-1) * loss_mask) / loss_mask.sum() 39 | return loss, mems, 'bert' 40 | 41 | 42 | def train_valid_datasets_provider(args, tokenizer): 43 | """Provide train and validation datasets.""" 44 | if args.task.lower() == 'blank': 45 | train_dataset = BlankLMDataset(args, split='train', tokenizer=tokenizer) 46 | valid_dataset = None 47 | elif args.task.lower() == 'extraction': 48 | train_dataset = ExtractionDataset(args, split='train', tokenizer=tokenizer) 49 | valid_dataset = None 50 | else: 51 | train_dataset = Seq2SeqDataset(args, split='train', tokenizer=tokenizer) 52 | valid_dataset = None 53 | global global_tokenizer 54 | global_tokenizer = tokenizer 55 | return train_dataset, valid_dataset 56 | 57 | 58 | def metrics_func_provider(args, tokenizer, is_test): 59 | """Provide metrics callback function.""" 60 | 61 | def single_dataset_provider(split): 62 | if args.task.lower() == 'blank': 63 | return BlankLMDataset(args, split=split, tokenizer=tokenizer) 64 | elif args.task.lower() == 'extraction': 65 | return ExtractionDataset(args, split=split, tokenizer=tokenizer) 66 | else: 67 | return Seq2SeqDataset(args, split=split, tokenizer=tokenizer) 68 | 69 | if args.task.lower() in ['blank', 'extraction']: 70 | evaluater = BlankLMEvaluater(args, tokenizer) 71 | eval_func = evaluater.evaluate 72 | metric_dict = {} 73 | else: 74 | evaluater = DecoderEvaluater(args, tokenizer) 75 | eval_func = evaluater.evaluate 76 | if args.tokenizer_type == "BertWordPieceTokenizer": 77 | dataset = 'cnn_dm' 78 | elif args.task.lower() == 'gigaword': 79 | dataset = 'gigaword' 80 | else: 81 | dataset = 'cnn_dm_org' 82 | if args.task.lower() in ['squad', 'squad_v1']: 83 | metric_dict = {"EM": squad_exact_match, "F1": squad_f1} 84 | else: 85 | 86 | 87 | 88 | metric_dict = {"F1": f1_metric} 89 | 90 | 91 | def output_func(predictions, examples, output_file): 92 | print("[Output]") 93 | if args.task.lower() in ['squad', 'squad_v1']: 94 | with open(output_file, "w", encoding='utf-8') as output: 95 | res = {} 96 | for prediction, example in zip(predictions, examples): 97 | idx = example.idx 98 | if prediction.lower().replace(' ', '') == 'n/a': 99 | prediction = '' 100 | if idx not in res or res[idx] == '': 101 | res[idx] = prediction 102 | json.dump(res, output) 103 | with open(output_file + ".refs", "w", encoding='utf-8') as output: 104 | for prediction, example in zip(predictions, examples): 105 | res = {'id': example.idx, 'pred': prediction, 'gold': example.meta['answers']} 106 | output.write(json.dumps(res) + '\n') 107 | return 108 | with open(output_file + ".hyps", "w", encoding='utf-8') as output: 109 | for prediction in predictions: 110 | output.write(prediction) 111 | output.write("\n") 112 | with open(output_file + ".refs", "w", encoding='utf-8') as output: 113 | for example in examples: 114 | output.write(example.meta["ref"]) 115 | output.write("\n") 116 | if args.task.lower() == 'squad_generation': 117 | with open(output_file + ".source", "w", encoding='utf-8') as output: 118 | for example in examples: 119 | output.write(example.text_a.replace("\n", " ") + " Answer: " + example.meta["answer"]) 120 | output.write("\n") 121 | 122 | return accuracy_func_provider(single_dataset_provider, metric_dict, args, is_test=is_test, eval_func=eval_func, 123 | output_func=output_func, only_rank0=False) 124 | 125 | 126 | def main(args): 127 | if args.src_seq_length > args.max_position_embeddings: 128 | args.max_position_embeddings = args.src_seq_length 129 | 
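    # Editor's note (comment added for exposition, not in the original file):
    # growing max_position_embeddings above keeps position-embedding lookups in
    # range when --src-seq-length exceeds the pretrained context window; the
    # whitelist below then routes every supported seq2seq-style task (including
    # the PRETRAIN task used for multi-task training) through the same
    # finetune() entry point, and anything else raises NotImplementedError.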
if args.task.lower() in ['cnn_dm', 'cnn_dm_original', 'gigaword', 'blank', 'squad_generation', 'xsum', 130 | 'squad', 'squad_v1', 'extraction', 'PRETRAIN']: 131 | finetune(args, train_valid_datasets_provider, {}, end_of_epoch_callback_provider=metrics_func_provider, 132 | forward_step=seq2seq_forward_step) 133 | else: 134 | raise NotImplementedError(args.task) 135 | -------------------------------------------------------------------------------- /src/glm/tasks/superglue/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Official evaluation script for ReCoRD v1.0. 3 | (Some functions are adopted from the SQuAD evaluation script.) 4 | """ 5 | 6 | from __future__ import print_function 7 | from collections import Counter 8 | import string 9 | import re 10 | from tasks.data_utils import InputExample 11 | from typing import List 12 | import functools 13 | from collections import defaultdict 14 | import unidecode 15 | 16 | 17 | def normalize_answer(s): 18 | """Lower text and remove punctuation, articles and extra whitespace.""" 19 | 20 | def remove_articles(text): 21 | return re.sub(r'\b(a|an|the)\b', ' ', text) 22 | 23 | def white_space_fix(text): 24 | return ' '.join(text.split()) 25 | 26 | def remove_punc(text): 27 | exclude = set(string.punctuation) 28 | return ''.join(ch for ch in text if ch not in exclude) 29 | 30 | def lower(text): 31 | return unidecode.unidecode(text.lower()) 32 | 33 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 34 | 35 | 36 | def f1_score(prediction, ground_truth): 37 | prediction_tokens = normalize_answer(prediction).split() 38 | ground_truth_tokens = normalize_answer(ground_truth).split() 39 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 40 | num_same = sum(common.values()) 41 | if num_same == 0: 42 | return 0 43 | precision = 1.0 * num_same / len(prediction_tokens) 44 | recall = 1.0 * num_same / len(ground_truth_tokens) 45 | f1 = (2 * precision * recall) / (precision + recall) 46 | return f1 47 | 48 | 49 | def exact_match_score(prediction, ground_truth): 50 | return normalize_answer(prediction) == normalize_answer(ground_truth) 51 | 52 | 53 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 54 | if not ground_truths: 55 | return 0.0 56 | scores_for_ground_truths = [] 57 | for ground_truth in ground_truths: 58 | score = metric_fn(prediction, ground_truth) 59 | scores_for_ground_truths.append(score) 60 | return max(scores_for_ground_truths) 61 | 62 | 63 | def qa_evaluate(predictions, labels, examples: List[InputExample], metric): 64 | assert len(examples) == len(predictions) 65 | score = 0.0 66 | for example, prediction in zip(examples, predictions): 67 | ground_truths = example.meta["answers"] 68 | prediction = example.meta["candidates"][prediction] 69 | if ground_truths: 70 | score += metric_max_over_ground_truths(metric, prediction, ground_truths) 71 | score = 100.0 * score / len(predictions) 72 | return score 73 | 74 | 75 | def squad_evaluate(predictions, labels, examples, metric): 76 | assert len(examples) == len(predictions) 77 | score = 0.0 78 | idx2predictions = {} 79 | idx2ground_truths = {} 80 | for example, prediction in zip(examples, predictions): 81 | idx = example.idx 82 | if idx not in idx2predictions: 83 | idx2predictions[idx] = [] 84 | idx2ground_truths[idx] = example.meta["answers"] 85 | idx2predictions[idx].append(prediction) 86 | 87 | for idx, predictions in idx2predictions.items(): 88 | prediction = 'N/A' 89 | for i in 
range(len(predictions)): 90 | prediction = predictions[i] 91 | if prediction.lower().replace(' ', '') == 'n/a': 92 | prediction = 'N/A' 93 | else: 94 | break 95 | ground_truths = idx2ground_truths[idx] 96 | if len(ground_truths) == 1 and ground_truths[0] == 'N/A': 97 | score += (prediction == 'N/A') 98 | else: 99 | score += metric_max_over_ground_truths(metric, prediction, ground_truths) 100 | score = 100.0 * score / len(idx2predictions) 101 | return score 102 | 103 | 104 | def multirc_em(predictions, labels, examples: List[InputExample]): 105 | """Compute the exact match (EM) for a sequence of predictions and actual labels""" 106 | question_ids = [example.meta["question_idx"] for example in examples] 107 | unique_questions = set(question_ids) 108 | 109 | q_actuals = list(zip(question_ids, labels)) 110 | q_predictions = list(zip(question_ids, predictions)) 111 | 112 | actuals_per_question = defaultdict(list) 113 | predictions_per_question = defaultdict(list) 114 | 115 | for qid, val in q_actuals: 116 | actuals_per_question[qid].append(val) 117 | for qid, val in q_predictions: 118 | predictions_per_question[qid].append(val) 119 | 120 | em = 0 121 | for qid in unique_questions: 122 | if actuals_per_question[qid] == predictions_per_question[qid]: 123 | em += 1 124 | em /= len(unique_questions) 125 | return em 126 | 127 | 128 | qa_exact_match = functools.partial(qa_evaluate, metric=exact_match_score) 129 | qa_f1 = functools.partial(qa_evaluate, metric=f1_score) 130 | 131 | squad_exact_match = functools.partial(squad_evaluate, metric=exact_match_score) 132 | squad_f1 = functools.partial(squad_evaluate, metric=f1_score) 133 | -------------------------------------------------------------------------------- /src/glm/tasks/superglue/finetune.py: -------------------------------------------------------------------------------- 1 | 2 | """Race.""" 3 | 4 | from collections import OrderedDict 5 | from finetune_glm import finetune 6 | from tasks.superglue.dataset import SuperGlueDataset, PROCESSORS, get_output_func 7 | from tasks.superglue.dataset import CLASSIFICATION_DATASETS, MULTI_CHOICE_DATASETS 8 | from tasks.superglue.evaluate import qa_exact_match, qa_f1, multirc_em, squad_exact_match, squad_f1 9 | from tasks.superglue.pvp import PVPS 10 | from tasks.eval_utils import accuracy_func_provider 11 | from tasks.eval_utils import accuracy_metric, f1_macro_metric, f1_metric 12 | 13 | DEFAULT_METRICS = { 14 | "record": [("EM", qa_exact_match), ("F1", qa_f1)], 15 | "copa": [("accuracy", accuracy_metric)], 16 | "rte": [("accuracy", accuracy_metric)], 17 | "boolq": [("accuracy", accuracy_metric)], 18 | "wic": [("accuracy", accuracy_metric)], 19 | "wsc": [("accuracy", accuracy_metric)], 20 | "cb": [("accuracy", accuracy_metric), ("f1-macro", f1_macro_metric)], 21 | "multirc": [("f1a", f1_metric), ("em", multirc_em), ("acc", accuracy_metric)], 22 | "mnli": [("accuracy", accuracy_metric)], 23 | "sst2": [("accuracy", accuracy_metric)], 24 | "qnli": [("accuracy", accuracy_metric)], 25 | "qqp": [("accuracy", accuracy_metric)], 26 | "mrpc": [("accuracy", accuracy_metric)], 27 | "cola": [("accuracy", accuracy_metric)], 28 | } 29 | 30 | 31 | def train_valid_datasets_provider(args, tokenizer, pattern_text=False): 32 | """Provide train and validation datasets.""" 33 | task_name = args.task.lower() 34 | data_dir = args.data_dir 35 | train_dataset = SuperGlueDataset(args, task_name, data_dir, args.seq_length, "train", tokenizer, 36 | pattern_text=pattern_text) 37 | valid_dataset = SuperGlueDataset(args, task_name, 
data_dir, args.seq_length, "dev", tokenizer, for_train=True, 38 | pattern_text=pattern_text) 39 | 40 | return train_dataset, valid_dataset 41 | 42 | 43 | def metrics_func_provider(args, tokenizer, is_test): 44 | """Privde metrics callback function.""" 45 | 46 | def single_dataset_provider(split): 47 | return SuperGlueDataset(args, args.task.lower(), args.data_dir, args.seq_length, split, tokenizer) 48 | 49 | output_func = get_output_func(args.task.lower(), args) 50 | eval_func = None 51 | if args.task.lower() == 'wsc' and args.cloze_eval and not args.wsc_negative: 52 | from tasks.language_model.finetune import classify_evaluate 53 | eval_func = classify_evaluate 54 | metric_dict = OrderedDict(DEFAULT_METRICS[args.task.lower()]) 55 | return accuracy_func_provider(single_dataset_provider, metric_dict, args, is_test=is_test, eval_func=eval_func, 56 | output_func=output_func, only_rank0=False, tokenizer=tokenizer) 57 | 58 | 59 | def main(args): 60 | model_kwargs = {} 61 | processor = PROCESSORS[args.task.lower()](args) 62 | pvp = PVPS[args.task.lower()](args, None, processor.get_labels(), args.seq_length, 63 | pattern_id=args.pattern_id, is_multi_token=args.multi_token, 64 | num_prompt_tokens=args.num_prompt_tokens) 65 | if args.continuous_prompt: 66 | model_kwargs["spell_length"] = pvp.spell_length 67 | if args.task.lower() == 'wsc' and args.cloze_eval and not args.wsc_negative: 68 | from tasks.language_model.finetune import lm_forward_step 69 | finetune(args, train_valid_datasets_provider, model_kwargs, 70 | end_of_epoch_callback_provider=metrics_func_provider, forward_step=lm_forward_step) 71 | else: 72 | if args.cloze_eval: 73 | multi_token = pvp.is_multi_token 74 | else: 75 | multi_token = args.task.lower() in MULTI_CHOICE_DATASETS 76 | args.multi_token = multi_token 77 | if not multi_token: 78 | model_kwargs["model_type"] = "multiple_choice" if args.cloze_eval else "classification" 79 | model_kwargs["multi_token"] = False 80 | model_kwargs["num_labels"] = len(processor.get_labels()) 81 | else: 82 | model_kwargs["model_type"] = "multiple_choice" 83 | model_kwargs["multi_token"] = True 84 | model_kwargs["num_labels"] = 1 85 | finetune(args, train_valid_datasets_provider, model_kwargs, 86 | end_of_epoch_callback_provider=metrics_func_provider) 87 | -------------------------------------------------------------------------------- /src/glm/zero_shot.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from new_eval_updated import defaultdict, token_recovery, INSTANCE_OF, fix 3 | from collections import Counter, OrderedDict 4 | import argparse 5 | 6 | import json 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument("--dataset", type=str) 13 | parser.add_argument("--path", type=str) 14 | 15 | return parser.parse_args() 16 | 17 | 18 | def generate_mapping(dataset, runs): 19 | dataset_path = "/dataset/fd5061f6/liuxiao/data" 20 | dataset_path = join(dataset_path, dataset) 21 | path = join("/dataset/fd5061f6/liuxiao/deepstruct/glm/runs", runs) 22 | hyps, refs, raw = open(join(path, 'test.jsonl.hyps')).readlines(), \ 23 | open(join(path, 'test.jsonl.refs')).readlines(), \ 24 | open(join(dataset_path,'test.source')).readlines() 25 | 26 | ner_mapping = defaultdict(list) 27 | rel_mapping = defaultdict(list) 28 | 29 | if len(refs)==0: 30 | return 1 31 | 32 | assert (len(hyps) == len(refs)) or (len(hyps) == 0) 33 | assert (len(hyps) == len(raw)) or (len(hyps) == 0) 34 | 35 | scores = []; gt_tails 
= set() 36 | for hyp, ref, text in (zip(hyps, refs, raw) if len(hyps)>0 else zip(refs, refs, raw)): 37 | result = defaultdict(set); score = defaultdict(float) 38 | text = text.split('Sentence : ')[-1].lower() 39 | text_split = text.split(' ') 40 | hyp = hyp[2:-2]; ref = ref[2:-2] 41 | 42 | hyp_ents, ref_ents, hyp_rels, ref_rels = dict(), dict(), set(), set() 43 | for triple in hyp.split(' ) ( '): 44 | triple = [ 45 | s.lower().strip('[ ').strip(' ]').strip() 46 | for s in triple.lower().strip().split(' ; ') 47 | ] 48 | if len(triple) != 3: 49 | continue 50 | triple[0] = token_recovery(triple[0], text_split) 51 | if triple[1] != INSTANCE_OF: 52 | triple[2] = token_recovery(triple[2], text_split) 53 | if (triple[0] not in text) or ( 54 | (triple[2] not in text) and (triple[1] != INSTANCE_OF) 55 | ): 56 | continue 57 | triple = (fix(triple[0]), triple[1], fix(triple[2])) 58 | if triple[1] != INSTANCE_OF: 59 | hyp_rels.add(tuple(triple)) 60 | else: 61 | hyp_ents[triple[0]] = triple[2] 62 | for triple in ref.split(' ) ( '): 63 | triple = [ 64 | s.lower().strip('[ ').strip(' ]').strip() 65 | for s in triple.lower().strip().split(' ; ') 66 | ] 67 | if len(triple) != 3: 68 | continue 69 | triple = (fix(triple[0]), triple[1], fix(triple[2])) 70 | if triple[1] != INSTANCE_OF: 71 | ref_rels.add(tuple(triple)) 72 | gt_tails.add(triple[1]) 73 | else: 74 | ref_ents[triple[0]] = triple[2] 75 | gt_tails.add(triple[2]) 76 | 77 | 78 | for ent_surface, ner_type in ref_ents.items(): 79 | if ent_surface in hyp_ents: 80 | ner_mapping[hyp_ents[ent_surface]].append(ner_type) 81 | 82 | 83 | 84 | 85 | 86 | for hyp_rel in hyp_rels: 87 | for ref_rel in ref_rels: 88 | if (hyp_rel[0] == ref_rel[0] or ref_rel[0] in hyp_rel[0]) and \ 89 | (hyp_rel[2] == ref_rel[2] or ref_rel[2] in hyp_rel[2]): 90 | rel_mapping[hyp_rel[1]].append(ref_rel[1]) 91 | _ner_mapping, _rel_mapping = ner_mapping.copy(), rel_mapping.copy() 92 | for zs_type in list(ner_mapping.keys()): 93 | ner_mapping[zs_type] = list(Counter(ner_mapping[zs_type]).items()) 94 | 95 | 96 | 97 | ner_mapping[zs_type] = ner_mapping[zs_type][0][0] 98 | for zs_type in list(rel_mapping.keys()): 99 | 100 | 101 | 102 | 103 | rel_mapping[zs_type] = rel_mapping[zs_type][0][0] 104 | 105 | if len(ner_mapping) > 0: 106 | print(json.dumps(ner_mapping, indent=4)) 107 | print("Save to:", join(dataset_path, 'zero_shot_ner_mapping.json')) 108 | json.dump(ner_mapping, open(join(dataset_path, 'zero_shot_ner_mapping.json'), 'w'), indent=4, ensure_ascii=False) 109 | json.dump(_ner_mapping, open(join(dataset_path, 'zero_shot_ner_mapping_stats.json'), 'w'), indent=4, ensure_ascii=False) 110 | if len(rel_mapping) > 0: 111 | print(json.dumps(_rel_mapping, indent=4)) 112 | print("Save to:", join(dataset_path, 'zero_shot_rel_mapping.json')) 113 | json.dump(rel_mapping, open(join(dataset_path, 'zero_shot_rel_mapping.json'), 'w'), indent=4, ensure_ascii=False) 114 | json.dump(_rel_mapping, open(join(dataset_path, 'zero_shot_rel_mapping_stats.json'), 'w'), indent=4, ensure_ascii=False) 115 | 116 | 117 | if __name__ == '__main__': 118 | args = get_args() 119 | generate_mapping(args.dataset, args.path) 120 | -------------------------------------------------------------------------------- /src/manager.py: -------------------------------------------------------------------------------- 1 | from pyheaven import * 2 | import subprocess 3 | import sys 4 | 5 | supported_tasks = [ 6 | "ace2005_joint_er", 7 | "ace2005_joint_er_re", 8 | "ace2005event_trigger", 9 | "ace2005event_argument", 10 | "ade", 11 | 
"ade0", 12 | "ade1", 13 | "ade2", 14 | "ade3", 15 | "ade4", 16 | "ade5", 17 | "ade6", 18 | "ade7", 19 | "ade8", 20 | "ade9", 21 | "ade_re", 22 | "ade_re0", 23 | "ade_re1", 24 | "ade_re2", 25 | "ade_re3", 26 | "ade_re4", 27 | "ade_re5", 28 | "ade_re6", 29 | "ade_re7", 30 | "ade_re8", 31 | "ade_re9", 32 | "atis", 33 | "conll04", 34 | "conll04_re", 35 | "conll05_srl_brown", 36 | "conll05_srl_wsj", 37 | "conll12_srl", 38 | "multi_woz", 39 | "nyt", 40 | "nyt_re", 41 | "snips" 42 | ] 43 | supported_model_args = { 44 | "model_blocklm_110M": 45 | """\"--block-lm \\ 46 | --cloze-eval \\ 47 | --task-mask \\ 48 | --num-layers 12 \\ 49 | --hidden-size 768 \\ 50 | --num-attention-heads 12 \\ 51 | --max-position-embeddings 1024 \\ 52 | --tokenizer-type BertWordPieceTokenizer \\ 53 | --load-pretrained {0}\"""", 54 | "model_blocklm_220M": 55 | """\"--block-lm \\ 56 | --cloze-eval \\ 57 | --task-mask \\ 58 | --num-layers 14 \\ 59 | --hidden-size 1024 \\ 60 | --num-attention-heads 16 \\ 61 | --max-position-embeddings 1024 \\ 62 | --tokenizer-model-type gpt2 \\ 63 | --tokenizer-type GPT2BPETokenizer \\ 64 | --load-pretrained {0}\"""", 65 | "model_blocklm_2B": 66 | """\"--block-lm \\ 67 | --cloze-eval \\ 68 | --task-mask \\ 69 | --num-layers 36 \\ 70 | --hidden-size 2048 \\ 71 | --num-attention-heads 32 \\ 72 | --max-position-embeddings 1024 \\ 73 | --tokenizer-model-type gpt2 \\ 74 | --tokenizer-type GPT2BPETokenizer \\ 75 | --load-pretrained {0}\"""", 76 | "model_blocklm_10B": 77 | """\"--block-lm \\ 78 | --cloze-eval \\ 79 | --task-mask \\ 80 | --num-layers 48 \\ 81 | --hidden-size 4096 \\ 82 | --num-attention-heads 64 \\ 83 | --max-position-embeddings 1024 \\ 84 | --tokenizer-model-type gpt2 \\ 85 | --tokenizer-type GPT2BPETokenizer \\ 86 | --load-pretrained {0}\"""", 87 | } 88 | supported_modes = [ 89 | "default", 90 | "multi", 91 | "empha", 92 | "task", 93 | ] 94 | train_args = """\"--epochs {0} \\ 95 | --batch-size 4 \\ 96 | --lr 1e-5 \\ 97 | --lr-decay-style linear \\ 98 | --warmup 0.06 \\ 99 | --weight-decay 1.0e-1 \\ 100 | --label-smoothing 0.1\"""" 101 | common_args = """\"--save-interval 10000 \\ 102 | --log-interval 10 \\ 103 | --eval-interval 10000 \\ 104 | --eval-iters 10000 \\ 105 | --eval-epoch 5\"""" 106 | task_args = """\"--src-seq-length {2} \\ 107 | --tgt-seq-length {3} \\ 108 | --min-tgt-length {4} \\ 109 | --length-penalty {1} \\ 110 | --no-repeat-ngram-size 0 \\ 111 | --num-beams {5} \\ 112 | --select-topk \\ 113 | --eval-batch-size 1\"""" 114 | 115 | if __name__ == "__main__": 116 | args = HeavenArguments.from_parser([ 117 | SwitchArgumentDescriptor("multi-server", short="ms"), 118 | 119 | StrArgumentDescriptor("model-checkpoint", short="ckpt", default=None), 120 | LiteralArgumentDescriptor("model-type", short="m", choices=list(supported_model_args.keys()) + ["auto"], 121 | default="auto"), 122 | LiteralArgumentDescriptor("mode", short="md", choices=supported_modes, default="multi"), 123 | SwitchArgumentDescriptor("zero-shot", short="zs"), 124 | 125 | LiteralArgumentDescriptor("task", short="t", choices=supported_tasks, default=None), 126 | IntArgumentDescriptor("task-epochs", short="e", default=50), 127 | IntArgumentDescriptor("num-beams", short="b", default=8), 128 | IntArgumentDescriptor("src-seq-length", short="srcl", default=512), 129 | IntArgumentDescriptor("tgt-seq-length", short="tgtl", default=512), 130 | IntArgumentDescriptor("min-tgt-length", short="tgtm", default=0), 131 | FloatArgumentDescriptor("length-penalty", short="lp", default=0.8), 132 | 
IntArgumentDescriptor("num-gpus-per-node", short="ngpn", default=1), 133 | ]) 134 | 135 | content = open('src/glm/scripts/ds_finetune_seq2seq.sh').read() 136 | with open('glm/scripts/ds_finetune_seq2seq.sh', 'w') as f: 137 | f.write(content.replace('NUM_GPUS_PER_WORKER=8', f'NUM_GPUS_PER_WORKER={args.num_gpus_per_node}')) 138 | 139 | if args.model_type == "auto": 140 | assert (args.model_checkpoint is not None) 141 | args.model_type = "model_blocklm_" + args.model_checkpoint.split('/')[-2].split('_')[0] 142 | if args.task.startswith("fewrel"): 143 | CreateFolder(f"../data/{args.task}/") 144 | CMD(f"cp -rf ../data/FewRelEpisodic/* ../data/{args.task}/") 145 | _, shot, way = args.task.split('_') 146 | args.way, args.shot = way.strip('way'), shot.strip('shot') 147 | 148 | CreateFolder("glm/runs") 149 | CreateFolder("glm/config_tasks") 150 | with open(f"glm/config_tasks/{args.model_type}.sh", "w") as f: 151 | f.write( 152 | f""" 153 | MODEL_TYPE={'-'.join(args.model_type.split('_')[1:])} 154 | MODEL_ARGS={supported_model_args[args.model_type].format(args.model_checkpoint)} 155 | """ 156 | ) 157 | 158 | with open(f"glm/config_tasks/seq_cnndm_org.sh", "w") as f: 159 | f.write( 160 | """ 161 | EXPERIMENT_NAME=${MODEL_TYPE}-cnndm_org 162 | TASK_NAME=cnn_dm_original 163 | DATA_PATH=\"${DATA_ROOT}/${TASK_DATASET}\" 164 | """ + 165 | f""" 166 | TRAIN_ARGS={train_args.format(args.task_epochs, args.length_penalty, args.src_seq_length, args.tgt_seq_length, args.min_tgt_length, args.num_beams)} 167 | COMMON_ARGS={common_args.format(args.task_epochs, args.length_penalty, args.src_seq_length, args.tgt_seq_length, args.min_tgt_length, args.num_beams)} 168 | TASK_ARGS={task_args.format(args.task_epochs, args.length_penalty, args.src_seq_length, args.tgt_seq_length, args.min_tgt_length, args.num_beams)} 169 | """ 170 | ) 171 | 172 | CreateFolder("scripts") 173 | with open(f"scripts/{args.task}.sh", "w") as f: 174 | run_command = f"bash scripts/ds_finetune_seq2seq{'_multiserver' if args.multi_server else ''}.sh config_tasks/{args.model_type}.sh config_tasks/seq_cnndm_org.sh {args.task}" 175 | if args.task.startswith("fewrel"): 176 | commands = "\n".join( 177 | f""" 178 | cd ../../data/{args.task}/ 179 | bash set_n_way_k_shot.sh {args.way}_{args.shot}_{i} 180 | cd ../../deepstruct/glm/ 181 | {run_command} 182 | cd ../../data/{args.task}/ 183 | bash reset_n_way_k_shot.sh {args.way}_{args.shot}_{i} 184 | cd ../../deepstruct/glm/ 185 | """ for i in range(10) 186 | ) 187 | else: 188 | commands = run_command 189 | f.write( 190 | f""" 191 | source PATH.sh 192 | cd ./dataset_processing 193 | python3 run.py {args.task} -mode {args.mode} --data_only 194 | cd ../ 195 | cd ./glm 196 | {commands} 197 | cd ../ 198 | """ 199 | ) 200 | if args.task in ['oie_nyt', 'oie_oie2016', 'oie_penn', 'oie_web']: 201 | f.write(f"python oie-eval/supervised-oie-benchmark/evaluate_oie.py -task {args.task.split('_')[-1]}\n") 202 | if args.task in ['conll12_coref', 'ace2005event_argument']: 203 | f.write( 204 | f"cd ./dataset_processing/ && python run.py {args.task} -mode multi --evaluate_only && cd ../") 205 | 206 | CreateFolder("logs") 207 | handler = subprocess.Popen(f"bash scripts/{args.task}.sh", 208 | shell=True, 209 | stdout=subprocess.PIPE) 210 | with open(f'logs/{args.task}_{FORMATTED_TIME()}.log', 'w') as file: 211 | for line in iter(lambda: handler.stdout.readline(), b""): 212 | output = line.decode(sys.stdout.encoding) 213 | if 'Iteration' in output or ('F1' in output and 'overall' not in output) or '###' in output: 214 | 
sys.stdout.write(output) 215 | file.write(output) 216 | 217 | with open(f"glm/runs/latest_run") as f: 218 | exp_name = f.readline().strip() 219 | 220 | CMD(f"cp -f glm/runs/{exp_name}/test.jsonl.hyps ../data/{args.task}/") 221 | 222 | CMD(f"python glm/evaluate.py -task {args.task}" + ["", " --zero-shot"][args.zero_shot]) 223 | -------------------------------------------------------------------------------- /src/run_scripts/ace2005_jer.sh: -------------------------------------------------------------------------------- 1 | bash ./data_scripts/ace2005_joint_er.sh $1 2 | bash ./tasks/mt/ace2005_ent.sh 3 | bash ./tasks/mt/ace2005_rel.sh 4 | -------------------------------------------------------------------------------- /src/run_scripts/ace2005event.sh: -------------------------------------------------------------------------------- 1 | bash ./data_scripts/ace2005event.sh $1 2 | bash ./tasks/mt/ace2005_trigger.sh 3 | cp ../data/ace2005event_trigger/test.target ../data/ace2005event_trigger/test.jsonl.refs 4 | bash ./tasks/mt/ace2005_argument.sh 5 | -------------------------------------------------------------------------------- /src/run_scripts/ade.sh: -------------------------------------------------------------------------------- 1 | bash ./data_scripts/ade.sh 2 | bash ./tasks/mt/ade0_ent.sh 3 | bash ./tasks/mt/ade0_rel.sh 4 | -------------------------------------------------------------------------------- /src/run_scripts/atis.sh: -------------------------------------------------------------------------------- 1 | bash ./data_scripts/atis.sh 2 | bash ./tasks/mt/atis.sh 3 | -------------------------------------------------------------------------------- /src/run_scripts/conll04.sh: -------------------------------------------------------------------------------- 1 | bash ./data_scripts/conll04.sh 2 | bash ./tasks/mt/conll04_ent.sh 3 | bash ./tasks/mt/conll04_rel.sh 4 | -------------------------------------------------------------------------------- /src/run_scripts/conll05_srl_brown.sh: -------------------------------------------------------------------------------- 1 | bash ./data_scripts/conll05_srl.sh $1 2 | bash ./tasks/mt/conll05_brown.sh 3 | -------------------------------------------------------------------------------- /src/run_scripts/conll05_srl_wsj.sh: -------------------------------------------------------------------------------- 1 | bash ./data_scripts/conll05_srl.sh $1 2 | bash ./tasks/mt/conll05_wsj.sh 3 | -------------------------------------------------------------------------------- /src/run_scripts/conll12_srl.sh: -------------------------------------------------------------------------------- 1 | bash ./data_scripts/conll12_srl.sh $1 2 | bash ./tasks/mt/conll12.sh 3 | -------------------------------------------------------------------------------- /src/run_scripts/multi_woz.sh: -------------------------------------------------------------------------------- 1 | bash ./data_scripts/multi_woz.sh 2 | bash ./tasks/mt/multi_woz.sh 3 | -------------------------------------------------------------------------------- /src/run_scripts/nyt.sh: -------------------------------------------------------------------------------- 1 | bash ./data_scripts/nyt.sh 2 | bash ./tasks/mt/nyt_ent.sh 3 | bash ./tasks/mt/nyt_rel.sh 4 | -------------------------------------------------------------------------------- /src/run_scripts/snips.sh: -------------------------------------------------------------------------------- 1 | bash ./data_scripts/snips.sh 2 | bash ./tasks/mt/snips.sh 3 | 
-------------------------------------------------------------------------------- /src/tasks/mt/ace2005_argument.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B_1/ \ 3 | --task ace2005event_argument \ 4 | --task-epochs 0 \ 5 | --length-penalty 0.3 6 | -------------------------------------------------------------------------------- /src/tasks/mt/ace2005_ent.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B_1/ \ 3 | --task ace2005_joint_er \ 4 | --task-epochs 0 \ 5 | --length-penalty 0.8 6 | -------------------------------------------------------------------------------- /src/tasks/mt/ace2005_rel.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B_1/ \ 3 | --task ace2005_joint_er_re \ 4 | --task-epochs 0 \ 5 | --length-penalty 0.3 6 | -------------------------------------------------------------------------------- /src/tasks/mt/ace2005_trigger.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B_1/ \ 3 | --task ace2005event_trigger \ 4 | --task-epochs 0 \ 5 | --length-penalty 0.3 6 | -------------------------------------------------------------------------------- /src/tasks/mt/ade0_ent.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B/ \ 3 | --task ade0 \ 4 | --task-epochs 0 \ 5 | --length-penalty 0.8 6 | -------------------------------------------------------------------------------- /src/tasks/mt/ade0_rel.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B/ \ 3 | --task ade_re0 \ 4 | --task-epochs 0 \ 5 | --length-penalty 0.3 6 | -------------------------------------------------------------------------------- /src/tasks/mt/atis.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B/ \ 3 | --task atis \ 4 | --task-epochs 0 \ 5 | --length-penalty 0.8 6 | -------------------------------------------------------------------------------- /src/tasks/mt/conll04_ent.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B_1/ \ 3 | --task conll04 \ 4 | --task-epochs 0 \ 5 | --length-penalty 0.8 6 | -------------------------------------------------------------------------------- /src/tasks/mt/conll04_rel.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B_1/ \ 3 | --task conll04_re \ 4 | --task-epochs 0 \ 5 | --length-penalty 0.3 6 | -------------------------------------------------------------------------------- /src/tasks/mt/conll05_brown.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py 
--model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B_1/ \ 3 | --task conll05_srl_brown \ 4 | --task-epochs 0 \ 5 | --length-penalty 1.0 6 | -------------------------------------------------------------------------------- /src/tasks/mt/conll05_wsj.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B_1/ \ 3 | --task conll05_srl_wsj \ 4 | --task-epochs 0 \ 5 | --length-penalty 1.0 6 | -------------------------------------------------------------------------------- /src/tasks/mt/conll12.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B_1/ \ 3 | --task conll12_srl \ 4 | --task-epochs 0 \ 5 | --length-penalty 1.0 6 | -------------------------------------------------------------------------------- /src/tasks/mt/multi_woz.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B/ \ 3 | --task multi_woz \ 4 | --task-epochs 0 \ 5 | --length-penalty 0.8 -------------------------------------------------------------------------------- /src/tasks/mt/nyt_ent.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B_1/ \ 3 | --task nyt \ 4 | --task-epochs 0 \ 5 | --length-penalty 0.8 6 | -------------------------------------------------------------------------------- /src/tasks/mt/nyt_rel.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B_1/ \ 3 | --task nyt_re \ 4 | --task-epochs 0 \ 5 | --length-penalty 0.3 6 | -------------------------------------------------------------------------------- /src/tasks/mt/snips.sh: -------------------------------------------------------------------------------- 1 | python3 manager.py --model-type model_blocklm_10B \ 2 | --model-checkpoint ../../ckpt/MP/10B/ \ 3 | --task snips \ 4 | --task-epochs 0 \ 5 | --length-penalty 0.8 6 | --------------------------------------------------------------------------------